Mirror of https://github.com/goauthentik/authentik (synced 2026-05-07 23:52:38 +02:00)
Compare commits: command-pa...rust-serve (178 commits)
Commit SHA1s:

04c066d8b0 f3341a4b83 27f652dcf3 dca2c2f536 5d426411dd 35ec2ea930 b7c4d04c16 8ef1b945e8 7fab5b6e93 7468a7271c
1a270f9c6e 3ae126cd99 6db2fbc8aa 32f6738a40 1ddc596362 1281371077 58508ebc4e aa614ad31c b9b1c7ccf6 f8209680fa
2b2c6a3b9b 62644a79fd c426c94a25 2e04738306 297e8db6eb 5b9a30be4b 457429f261 a0bac73c59 b82abaf230 c4b1e4bd44
5592c4769a f71f5b7278 d7159cfce2 30dc4e120b 619023be75 de63473cd2 6aa50b962c f240ca1708 550da2005e 8818a0b06c
013190ddd0 6fb777ae5b 41f13d8805 fc5f0e7dc5 9b9379ac8f c4b0825dad 946ace14c1 6a9eb8e9c7 4f0d0e72d5 411648672e
d5f6d30aeb 1508ad0ab8 892e8fd856 d4b0ac7c14 fe4857abbb 8b73872c0d d22597377a 58d198d60a 1de19546d7 8ad054ce65
df95fc89eb 75898710f1 3a5a0c2e4f b806e14a00 c2d02cd807 1212402231 2927f414c5 5ba18fbd55 1b108e40d6 982ae7b261
294a656ad2 dab8bab916 ee1803a0ae 99c9894a04 2352ce72c9 bb28e6425d f2149dfd90 2ff0f09db1 40a91fd4fb 2e3f76441c
f91474dd91 61dbd5976f 8099ac6508 61ed26e3f6 ea17d4cbf1 ac388667d0 cdc42de5b5 2770c3a7e0 f41f501702 08685a574a
15377f5154 52da505aab d8a2a069aa fec9dcc2e7 b644fa5a2c 9a5d59533e 3c64570398 a735f6dcf3 f33e7f13eb eee00fa29b
5a95a14a8f 7b46fac608 bb488e1c2c 138aa0e4e9 e65cd2999f 490790c272 b640b42dbb 1371465ebe c623b96dc2 43fe1918db
3e2489834d 7ba86b7de3 85ef3cda04 62911536bf 1a27971399 7a0e946bb5 428ccc2c14 0b706d5830 b9f4a1aed7 d2cb45aadf
de12748f25 f8f39b8edc 986385a951 129ed95cf0 dc0d535fcc 5c0e23a78f b4bf082864 2f00983c29 af93a1e230 dbb3898621
a668ddcaf5 051aea6f99 b8104ec156 e59970e6ab 0b50b0aa13 7b9b1c2c70 1e1cdffb33 8ad572ba35 8a5b8ad047 907a4ce478
a26254df02 bf9679dcb5 71ee2f6c66 90fb12a804 e271a8a0af 6100fd7800 b78d62f550 21eb1bb7d0 e4445a44c4 6fecbb41ca
4a840796bf cc7f190735 c4962f86dd ad672338e0 fadf344955 8c58873a3a ac7dd69be2 f01ab7ccb2 13f7ac6eca 24202f9a3f
5a72130576 fe5d24004e dd7c13c5bd 32de1ab6c6 6e4384d672 79f7759d4b 0ca41cb184 f8e5c895d6 2ba8991a3b 19b36d2e0d
fb802a53bc 2f6465d5a0 c5437d2b0b 8e2e90a87f 4deb3d45cf b61bb3cc17 af3332df9f 0849df7478
2  .cargo/config.toml  (new file)
@@ -0,0 +1,2 @@
+[build]
+rustflags = ["--cfg", "tokio_unstable"]
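The `tokio_unstable` cfg is injected globally because two things later in this diff depend on it: the `console-subscriber` crate added to `Cargo.toml`, and the named tasks spawned through `JoinSet::build_task` in `src/arbiter.rs`. A minimal sketch (not from the diff) of what the flag unlocks:

```rust
// Illustrative only: APIs gated behind --cfg tokio_unstable.
use std::time::Duration;

use tokio::task::JoinSet;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // tokio-console instrumentation; it only yields task data when the binary
    // is compiled with --cfg tokio_unstable, which .cargo/config.toml ensures.
    console_subscriber::init();

    let mut tasks: JoinSet<()> = JoinSet::new();
    tasks
        .build_task() // JoinSet::build_task is itself tokio_unstable-only
        .name("example::heartbeat")
        .spawn(async {
            tokio::time::sleep(Duration::from_millis(10)).await;
        })?;

    while tasks.join_next().await.is_some() {}
    Ok(())
}
```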
@@ -1,5 +1,17 @@
 [licenses]
-allow = ["Apache-2.0", "MIT", "MPL-2.0", "Unicode-3.0"]
+allow = [
+    "Apache-2.0 WITH LLVM-exception",
+    "Apache-2.0",
+    "BSD-3-Clause",
+    "CC0-1.0",
+    "CDLA-Permissive-2.0",
+    "ISC",
+    "MIT",
+    "MPL-2.0",
+    "OpenSSL",
+    "Unicode-3.0",
+    "Zlib",
+]
 
 [licenses.private]
 ignore = true
8  .github/workflows/ci-main.yml  (vendored)
@@ -144,6 +144,7 @@ jobs:
       CI_TEST_SEED: ${{ needs.test-make-seed.outputs.seed }}
       CI_RUN_ID: ${{ matrix.run_id }}
       CI_TOTAL_RUNS: "5"
+      PROMETHEUS_MULTIPROC_DIR: /tmp
     run: |
       uv run make ci-test
   - uses: ./.github/actions/test-results
@@ -173,6 +174,7 @@
       CI_TEST_SEED: ${{ needs.test-make-seed.outputs.seed }}
       CI_RUN_ID: ${{ matrix.run_id }}
       CI_TOTAL_RUNS: "5"
+      PROMETHEUS_MULTIPROC_DIR: /tmp
     run: |
       uv run make ci-test
   - uses: ./.github/actions/test-results
@@ -189,6 +191,8 @@
   - name: Create k8s Kind Cluster
     uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # v1.14.0
   - name: run integration
+    env:
+      PROMETHEUS_MULTIPROC_DIR: /tmp
     run: |
       uv run coverage run manage.py test tests/integration
       uv run coverage xml
@@ -245,6 +249,8 @@
       npm run build
       npm run build:sfe
   - name: run e2e
+    env:
+      PROMETHEUS_MULTIPROC_DIR: /tmp
     run: |
       uv run coverage run manage.py test ${{ matrix.job.glob }}
       uv run coverage xml
@@ -288,6 +294,8 @@
       npm run build
       npm run build:sfe
   - name: run conformance
+    env:
+      PROMETHEUS_MULTIPROC_DIR: /tmp
     run: |
       uv run coverage run manage.py test ${{ matrix.job.glob }}
       uv run coverage xml
4916  Cargo.lock  (generated; diff suppressed because it is too large)
167  Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-members = ["website/scripts/docsmg"]
+members = [".", "website/scripts/docsmg"]
 resolver = "3"
 
 [workspace.package]
@@ -12,11 +12,101 @@ license-file = "LICENSE"
publish = false

[workspace.dependencies]
arc-swap = "1.8.2"
argh = "0.1.17"
async-trait = "0.1.89"
aws-lc-rs = { version = "1.16.1", features = ["fips"] }
axum = { version = "0.8.8", features = ["http2", "macros", "ws"] }
axum-server = { version = "0.8.0", features = ["tls-rustls-no-provider"] }
bytes = "1.11.1"
chrono = "0.4.44"
clap = { version = "4.5.59", features = ["derive", "env"] }
client-ip = { version = "0.2.1", features = ["forwarded-header"] }
color-eyre = "0.6.5"
colored = "3.1.1"
config = { version = "0.15.19", default-features = false, features = [
    "yaml",
    "async",
] }
console-subscriber = "0.5.0"
dotenvy = "0.15.7"
durstr = "0.4.0"
eyre = "0.6.12"
forwarded-header-value = "0.1.1"
futures = "0.3.32"
glob = "0.3.3"
http-body-util = "0.1.3"
hyper = "1.8.1"
hyper-unix-socket = "0.3.0"
hyper-util = "0.1.20"
ipnet = { version = "2.12.0", features = ["serde"] }
# See https://github.com/mladedav/json-subscriber/pull/23
json-subscriber = { git = "https://github.com/rissson/json-subscriber.git", rev = "950ad7cb887a0a14fd5cb8afb8e76db1f456c032" }
jsonwebtoken = { version = "10.3.0", default-features = false, features = [
    "aws_lc_rs",
] }
metrics = "0.24.3"
metrics-exporter-prometheus = { version = "0.18.1", default-features = false }
nix = { version = "0.31.2", features = ["hostname", "signal"] }
notify = "8.2.0"
pem = "3.0.6"
pin-project-lite = "0.2.17"
pyo3 = "0.28.2"
percent-encoding = "2.3.2"
rcgen = { version = "0.14.7", default-features = false, features = [
    "aws_lc_rs",
    "fips",
] }
regex = "1.12.3"
rustls = { version = "0.23.37", features = ["fips"] }
sentry = { version = "0.47.0", default-features = false, features = [
    "backtrace",
    "contexts",
    "debug-images",
    "panic",
    "rustls",
    "reqwest",
    "tower",
    "tracing",
] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
sqlx = { version = "0.8.6", default-features = false, features = [
    "runtime-tokio",
    "tls-rustls-aws-lc-rs",
    "postgres",
    "derive",
    "macros",
    "uuid",
    "chrono",
    "ipnet",
    "json",
] }
time = "0.3.47"
thiserror = "2.0.18"
tokio = { version = "1.50.0", features = ["full"] }
tokio-rustls = "0.26.4"
tokio-tungstenite = "0.28.0"
tokio-util = "0.7.18"
tower = "0.5.3"
tower-http = { version = "0.6.8", features = [
    "compression-br",
    "compression-deflate",
    "compression-gzip",
    "compression-zstd",
    "fs",
    "timeout",
] }
tower-service = "0.3.3"
tracing = "0.1.44"
tracing-error = "0.2.1"
tracing-subscriber = { version = "0.3.22", features = [
    "env-filter",
    "json",
    "tracing-log",
] }
url = "2.5.8"
uuid = { version = "1.22.0", features = ["v4"] }

[profile.dev.package.backtrace]
opt-level = 3
@@ -60,10 +150,14 @@ perf = { priority = -1, level = "warn" }
style = { priority = -1, level = "warn" }
suspicious = { priority = -1, level = "warn" }
### and disable the ones we don't want
### cargo group
multiple_crate_versions = "allow"
### pedantic group
redundant_closure_for_method_calls = "allow"
struct_field_names = "allow"
too_many_lines = "allow"
### nursery
missing_const_for_fn = "allow"
redundant_pub_crate = "allow"
option_if_let_else = "allow"
### restriction group
@@ -78,7 +172,6 @@ create_dir = "warn"
dbg_macro = "warn"
default_numeric_fallback = "warn"
disallowed_script_idents = "warn"
doc_paragraphs_missing_punctuation = "warn"
empty_drop = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
@@ -131,3 +224,73 @@ unused_trait_names = "warn"
 unwrap_in_result = "warn"
 unwrap_used = "warn"
 verbose_file_reads = "warn"
+
+[package]
+name = "authentik"
+version = "2026.5.0-rc1"
+authors.workspace = true
+edition.workspace = true
+readme.workspace = true
+homepage.workspace = true
+repository.workspace = true
+license-file.workspace = true
+publish.workspace = true
+
+[features]
+default = ["core", "proxy"]
+proxy = []
+core = ["proxy", "dep:sqlx", "dep:pyo3"]
+
+[dependencies]
+arc-swap.workspace = true
+argh.workspace = true
+async-trait.workspace = true
+aws-lc-rs.workspace = true
+axum-server.workspace = true
+axum.workspace = true
+client-ip.workspace = true
+color-eyre.workspace = true
+config.workspace = true
+console-subscriber.workspace = true
+durstr.workspace = true
+eyre.workspace = true
+forwarded-header-value.workspace = true
+futures.workspace = true
+glob.workspace = true
+http-body-util.workspace = true
+hyper-unix-socket.workspace = true
+hyper-util.workspace = true
+hyper.workspace = true
+ipnet.workspace = true
+json-subscriber.workspace = true
+jsonwebtoken.workspace = true
+metrics.workspace = true
+metrics-exporter-prometheus.workspace = true
+nix.workspace = true
+notify.workspace = true
+pem.workspace = true
+percent-encoding.workspace = true
+pin-project-lite.workspace = true
+pyo3 = { workspace = true, optional = true }
+rcgen.workspace = true
+rustls.workspace = true
+sentry.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+sqlx = { workspace = true, optional = true }
+thiserror.workspace = true
+time.workspace = true
+tokio-rustls.workspace = true
+tokio-tungstenite.workspace = true
+tokio-util.workspace = true
+tokio.workspace = true
+tower-http.workspace = true
+tower.workspace = true
+tracing-error.workspace = true
+tracing-subscriber.workspace = true
+tracing.workspace = true
+url.workspace = true
+uuid.workspace = true
+
+[lints]
+workspace = true
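The new `[features]` block makes `sqlx` and `pyo3` optional (`dep:` syntax), so a proxy-only binary can be built without the database and Python embedding stacks. A hedged sketch of the gating this implies; the module and function names are illustrative, not from the diff:

```rust
// Illustrative feature gating. With default = ["core", "proxy"] both halves
// compile; `cargo build --no-default-features --features proxy` drops sqlx
// and pyo3 from the dependency graph entirely.
#[cfg(feature = "core")]
pub mod core_db {
    /// Only exists in builds with the `core` feature, which pulls in dep:sqlx.
    pub async fn connect(url: &str) -> Result<sqlx::PgPool, sqlx::Error> {
        sqlx::PgPool::connect(url).await
    }
}

#[cfg(feature = "proxy")]
pub mod proxy {
    /// Available in default builds and in proxy-only builds.
    pub const ENABLED: bool = true;
}
```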
24  Makefile
@@ -84,7 +84,7 @@ test: ## Run the server tests and produce a coverage report (locally)
 lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
 	$(UV) run black $(PY_SOURCES)
 	$(UV) run ruff check --fix $(PY_SOURCES)
-	$(CARGO) +nightly fmt --all -- --config-path .config/rustfmt.toml
+	$(CARGO) +nightly fmt --all -- --config-path .cargo/rustfmt.toml
 
 lint-spellcheck: ## Reports spelling errors.
 	npm run lint:spellcheck
@@ -110,12 +110,24 @@ i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that requir
 aws-cfn:
 	cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
 
-run-server: ## Run the main authentik server process
+run: ## Run the authentik server and worker, without auto reloading
+	$(UV) run ak allinone
+
+run-watch: ## Run the authentik server and worker, with auto reloading
+	$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak allinone
+
+run-server: ## Run the authentik server, without auto reloading
 	$(UV) run ak server
 
-run-worker: ## Run the main authentik worker process
+run-server-watch: ## Run the authentik server, with auto reloading
+	$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak server
+
+run-worker: ## Run the authentik worker, without auto reloading
 	$(UV) run ak worker
 
+run-worker-watch: ## Run the authentik worker, with auto reloading
+	$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak worker
+
 core-i18n-extract:
 	$(UV) run ak makemessages \
 		--add-location file \
@@ -154,7 +166,7 @@ ifndef version
 	$(error Usage: make bump version=20xx.xx.xx )
 endif
 	$(eval current_version := $(shell cat ${PWD}/internal/constants/VERSION))
-	$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml
+	$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml ${PWD}/Cargo.toml
 	$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' ${PWD}/authentik/__init__.py
 	$(MAKE) gen-build gen-compose aws-cfn
 	$(SED_INPLACE) "s/\"${current_version}\"/\"$(version)\"/" ${PWD}/package.json ${PWD}/package-lock.json ${PWD}/web/package.json ${PWD}/web/package-lock.json
@@ -359,13 +371,13 @@ ci-lint-pending-migrations: ci--meta-debug
 	$(UV) run ak makemigrations --check
 
 ci-lint-cargo-deny: ci--meta-debug
-	$(CARGO) deny --locked --workspace check --config .config/deny.toml
+	$(CARGO) deny --locked --workspace check --config .cargo/deny.toml
 
 ci-lint-cargo-machete: ci--meta-debug
 	$(CARGO) machete
 
 ci-lint-rustfmt: ci--meta-debug
-	$(CARGO) +nightly fmt --all --check -- --config-path .config/rustfmt.toml
+	$(CARGO) +nightly fmt --all --check -- --config-path .cargo/rustfmt.toml
 
 ci-lint-clippy: ci--meta-debug
 	$(CARGO) clippy -- -D warnings
@@ -92,6 +92,7 @@ class FileBackend(ManageableBackend):
                 "nbf": now() - timedelta(seconds=15),
             },
             key=sha256(f"{settings.SECRET_KEY}:{self.usage}".encode()).hexdigest(),
+            # Must match crates/authentik-server/src/static.rs
             algorithm="HS256",
         )
         url = f"{prefix}/files/{path}?token={token}"
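The added comment pins a cross-language contract: the Rust server must derive the same HMAC key, the hex digest of `sha256(f"{SECRET_KEY}:{usage}")`, to validate these file tokens. A hedged sketch of the verifying side; this is not the actual `crates/authentik-server/src/static.rs`, the claim set is assumed, and the `sha2`/`hex` crates are used purely for illustration (only `jsonwebtoken` appears in the diff's dependency list):

```rust
// Illustrative verifier for the HS256 file token minted by FileBackend above.
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use serde::Deserialize;
use sha2::{Digest, Sha256};

#[derive(Deserialize)]
struct FileClaims {
    exp: i64, // assumed present: jsonwebtoken validates exp by default
    nbf: i64,
}

fn validate_file_token(token: &str, secret_key: &str, usage: &str) -> bool {
    // Python signs with sha256(...).hexdigest(): the *hex string* is the key.
    let key = hex::encode(Sha256::digest(format!("{secret_key}:{usage}")));
    let mut validation = Validation::new(Algorithm::HS256);
    validation.validate_nbf = true;
    decode::<FileClaims>(token, &DecodingKey::from_secret(key.as_bytes()), &validation).is_ok()
}
```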
@@ -1,7 +1,5 @@
 """Apply blueprint from commandline"""
 
-from sys import exit as sys_exit
-
 from django.core.management.base import BaseCommand, no_translations
 from structlog.stdlib import get_logger
 
@@ -28,7 +26,7 @@ class Command(BaseCommand):
             self.stderr.write("Blueprint invalid")
             for log in logs:
                 self.stderr.write(f"\t{log.logger}: {log.event}: {log.attributes}")
-            sys_exit(1)
+            raise RuntimeError("Blueprint invalid")
         importer.apply()
 
     def add_arguments(self, parser):
@@ -342,10 +342,10 @@ def django_db_config(config: ConfigLoader | None = None) -> dict:
         "default": {
             "ENGINE": "psqlextra.backend",
             "HOST": config.get("postgresql.host"),
-            "NAME": config.get("postgresql.name"),
-            "PORT": config.get("postgresql.port"),
             "USER": config.get("postgresql.user"),
             "PASSWORD": config.get("postgresql.password"),
+            "PORT": config.get("postgresql.port"),
+            "NAME": config.get("postgresql.name"),
             "OPTIONS": {
                 "sslmode": config.get("postgresql.sslmode"),
                 "sslrootcert": config.get("postgresql.sslrootcert"),
@@ -423,4 +423,5 @@ if __name__ == "__main__":
     if len(argv) < 2:  # noqa: PLR2004
         print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
     else:
-        print(CONFIG.get(argv[-1]))
+        for arg in argv[1:]:
+            print(CONFIG.get(arg))
@@ -17,11 +17,13 @@
 
 postgresql:
   host: localhost
-  name: authentik
-  user: authentik
   port: 5432
+  user: authentik
   password: "env://POSTGRES_PASSWORD"
+  name: authentik
   sslmode: disable
   conn_max_age: 60
+  conn_health_checks: false
+  use_pool: False
   test:
     name: test_authentik
@@ -72,6 +74,19 @@ log_level: info
 log:
   http_headers:
     - User-Agent
+  rust_log:
+    "console_subscriber": info
+    "h2": info
+    "hyper_util": warn
+    "mio": info
+    "notify": info
+    "reqwest": info
+    "runtime": info
+    "rustls": info
+    "sqlx": info
+    "sqlx_postgres": info
+    "tokio": info
+    "tungstenite": info
 
 sessions:
   unauthenticated_age: days=1
@@ -143,8 +158,7 @@ tenants:
   blueprints_dir: /blueprints
 
 web:
-  # No default here as it's set dynamically
-  # workers: 2
+  workers: 2
   threads: 4
   path: /
   timeout_http_read_header: 5s
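The new `rust_log` map hands per-target log levels to the Rust process. One plausible way such a map is folded into a `tracing-subscriber` filter (the conversion function is illustrative, not code from the diff):

```rust
// Illustrative: {target: level} map -> EnvFilter directive string, e.g.
// "info,hyper_util=warn,sqlx=info".
use std::collections::BTreeMap;

use tracing_subscriber::EnvFilter;

fn env_filter(default_level: &str, rust_log: &BTreeMap<String, String>) -> EnvFilter {
    let mut directives = vec![default_level.to_owned()];
    directives.extend(rust_log.iter().map(|(target, level)| format!("{target}={level}")));
    EnvFilter::new(directives.join(","))
}
```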
@@ -41,7 +41,7 @@ def structlog_configure():
         add_process_id,
         add_tenant_information,
         structlog.stdlib.PositionalArgumentsFormatter(),
-        structlog.processors.TimeStamper(fmt="iso", utc=False),
+        structlog.processors.TimeStamper(fmt="iso", utc=True),
         structlog.processors.StackInfoRenderer(),
         structlog.processors.ExceptionRenderer(
             structlog.tracebacks.ExceptionDictTransformer(show_locals=CONFIG.get_bool("debug"))
@@ -339,6 +339,9 @@ class LoggingMiddleware:
 
     def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):
         """Log request"""
+        # Those are logged by the server above
+        if request.path in ("/-/metrics/", "/-/health/ready/"):
+            return
         for header in self.headers_to_log:
             header_value = request.headers.get(header)
             if not header_value:
@@ -1,37 +1,21 @@
 """Metrics view"""
 
-from hmac import compare_digest
-from pathlib import Path
-from tempfile import gettempdir
-
-from django.conf import settings
 from django.db import connections
 from django.db.utils import OperationalError
 from django.dispatch import Signal
 from django.http import HttpRequest, HttpResponse
 from django.views import View
-from django_prometheus.exports import ExportToDjangoView
 
 monitoring_set = Signal()
 
 
 class MetricsView(View):
-    """Wrapper around ExportToDjangoView with authentication, accessed by the authentik router"""
-
-    def __init__(self, **kwargs):
-        _tmp = Path(gettempdir())
-        with open(_tmp / "authentik-core-metrics.key") as _f:
-            self.monitoring_key = _f.read()
+    """View for metrics monitoring_set signal, accessed by the authentik router"""
 
     def get(self, request: HttpRequest) -> HttpResponse:
-        """Check for HTTP-Basic auth"""
-        auth_header = request.META.get("HTTP_AUTHORIZATION", "")
-        auth_type, _, given_credentials = auth_header.partition(" ")
-        authed = auth_type == "Bearer" and compare_digest(given_credentials, self.monitoring_key)
-        if not authed and not settings.DEBUG:
-            return HttpResponse(status=401)
         monitoring_set.send_robust(self)
-        return ExportToDjangoView(request)
+        return HttpResponse(status=204)
 
 
 class LiveView(View):
@@ -440,8 +440,6 @@ DRAMATIQ = {
     ("authentik.tasks.middleware.TaskLogMiddleware", {}),
     ("authentik.tasks.middleware.LoggingMiddleware", {}),
     ("authentik.tasks.middleware.DescriptionMiddleware", {}),
-    ("authentik.tasks.middleware.WorkerHealthcheckMiddleware", {}),
-    ("authentik.tasks.middleware.WorkerStatusMiddleware", {}),
     (
         "authentik.tasks.middleware.MetricsMiddleware",
         {
@@ -14,12 +14,12 @@ class TestRoot(TransactionTestCase):
     def setUp(self):
         _tmp = Path(gettempdir())
         self.token = token_urlsafe(32)
-        with open(_tmp / "authentik-core-metrics.key", "w") as _f:
+        with open(_tmp / "authentik-metrics-gunicorn.key", "w") as _f:
             _f.write(self.token)
 
     def tearDown(self):
         _tmp = Path(gettempdir())
-        (_tmp / "authentik-core-metrics.key").unlink()
+        (_tmp / "authentik-metrics-gunicorn.key").unlink()
 
     def test_monitoring_error(self):
         """Test monitoring without any credentials"""
@@ -1,5 +1,6 @@
-import pglock
-from django.utils.timezone import now, timedelta
+from datetime import timedelta
+
+from django.utils.timezone import now
 from drf_spectacular.utils import extend_schema, inline_serializer
 from packaging.version import parse
 from rest_framework.fields import BooleanField, CharField
@@ -31,18 +32,13 @@ class WorkerView(APIView):
     def get(self, request: Request) -> Response:
         response = []
         our_version = parse(authentik_full_version())
-        for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(minutes=2)):
-            lock_id = f"goauthentik.io/worker/status/{status.pk}"
-            with pglock.advisory(lock_id, timeout=0, side_effect=pglock.Return) as acquired:
-                # The worker doesn't hold the lock, it isn't running
-                if acquired:
-                    continue
-            version_matching = parse(status.version) == our_version
-            response.append(
-                {
-                    "worker_id": f"{status.pk}@{status.hostname}",
-                    "version": status.version,
-                    "version_matching": version_matching,
-                }
-            )
+        for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(seconds=45)):
+            version_matching = parse(status.version) == our_version
+            response.append(
+                {
+                    "worker_id": f"{status.pk}@{status.hostname}",
+                    "version": status.version,
+                    "version_matching": version_matching,
+                }
+            )
         return Response(response)
@@ -1,42 +1,23 @@
import socket
from collections.abc import Callable
from http.server import BaseHTTPRequestHandler
from threading import Event as TEvent
from threading import Thread, current_thread
from typing import Any, cast

import pglock
from django.db import OperationalError, connections, transaction
from django.utils.timezone import now
from django.db import OperationalError
from django_dramatiq_postgres.middleware import (
    CurrentTask as BaseCurrentTask,
)
from django_dramatiq_postgres.middleware import (
    HTTPServer,
    HTTPServerThread,
)
from django_dramatiq_postgres.middleware import (
    MetricsMiddleware as BaseMetricsMiddleware,
)
from django_dramatiq_postgres.middleware import (
    _MetricsHandler as BaseMetricsHandler,
)
from dramatiq import Worker
from dramatiq.broker import Broker
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from psycopg.errors import Error
from setproctitle import setthreadtitle
from structlog.stdlib import get_logger

from authentik import authentik_full_version
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.utils.reflection import class_to_path
from authentik.root.monitoring import monitoring_set
from authentik.root.signals import post_startup, pre_startup, startup
from authentik.tasks.models import Task, TaskLog, TaskStatus, WorkerStatus
from authentik.tasks.models import Task, TaskLog, TaskStatus
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
@@ -193,154 +174,15 @@ class DescriptionMiddleware(Middleware):
        return {"description"}


class _healthcheck_handler(BaseHTTPRequestHandler):
    def log_request(self, code="-", size="-"):
        HEALTHCHECK_LOGGER.info(
            self.path,
            method=self.command,
            status=code,
        )

    def log_error(self, format, *args):
        HEALTHCHECK_LOGGER.warning(format, *args)

    def do_HEAD(self):
        try:
            for db_conn in connections.all():
                # Force connection reload
                db_conn.connect()
                _ = db_conn.cursor()
            self.send_response(200)
        except DB_ERRORS:  # pragma: no cover
            self.send_response(503)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", "0")
        self.end_headers()

    do_GET = do_HEAD


class WorkerHealthcheckMiddleware(Middleware):
    thread: HTTPServerThread | None

    def __init__(self):
        listen = CONFIG.get("listen.http", ["[::]:9000"])
        if isinstance(listen, str):
            listen = listen.split(",")
        host, _, port = listen[0].rpartition(":")

        try:
            port = int(port)
        except ValueError:
            LOGGER.error(f"Invalid port entered: {port}")

        self.host, self.port = host, port

    def after_worker_boot(self, broker: Broker, worker: Worker):
        self.thread = HTTPServerThread(
            target=WorkerHealthcheckMiddleware.run, args=(self.host, self.port)
        )
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        server = self.thread.server
        if server:
            server.shutdown()
        LOGGER.debug("Stopping WorkerHealthcheckMiddleware")
        self.thread.join()

    @staticmethod
    def run(addr: str, port: int):
        setthreadtitle("authentik Worker Healthcheck server")
        try:
            server = HTTPServer((addr, port), _healthcheck_handler)
            thread = cast(HTTPServerThread, current_thread())
            thread.server = server
            server.serve_forever()
        except OSError as exc:
            get_logger(__name__, type(WorkerHealthcheckMiddleware)).warning(
                "Port is already in use, not starting healthcheck server",
                exc=exc,
            )


class WorkerStatusMiddleware(Middleware):
    thread: Thread | None
    thread_event: TEvent | None

    def after_worker_boot(self, broker: Broker, worker: Worker):
        self.thread_event = TEvent()
        self.thread = Thread(target=WorkerStatusMiddleware.run, args=(self.thread_event,))
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        self.thread_event.set()
        LOGGER.debug("Stopping WorkerStatusMiddleware")
        self.thread.join()

    @staticmethod
    def run(event: TEvent):
        setthreadtitle("authentik Worker status")
        with transaction.atomic():
            hostname = socket.gethostname()
            WorkerStatus.objects.filter(hostname=hostname).delete()
            status, _ = WorkerStatus.objects.update_or_create(
                hostname=hostname,
                version=authentik_full_version(),
            )
        while not event.is_set():
            try:
                WorkerStatusMiddleware.keep(event, status)
            except DB_ERRORS:  # pragma: no cover
                event.wait(10)
                try:
                    connections.close_all()
                except DB_ERRORS:
                    pass

    @staticmethod
    def keep(event: TEvent, status: WorkerStatus):
        lock_id = f"goauthentik.io/worker/status/{status.pk}"
        with pglock.advisory(lock_id, side_effect=pglock.Raise):
            while not event.is_set():
                status.refresh_from_db()
                old_last_seen = status.last_seen
                status.last_seen = now()
                if old_last_seen != status.last_seen:
                    status.save(update_fields=("last_seen",))
                event.wait(30)


class _MetricsHandler(BaseMetricsHandler):
    def do_GET(self) -> None:
        monitoring_set.send_robust(self)
        return super().do_GET()


class MetricsMiddleware(BaseMetricsMiddleware):
    thread: HTTPServerThread | None
    handler_class = _MetricsHandler

    @property
    def forks(self) -> list[Callable[[], None]]:
    def forks(self):
        return []

    def after_worker_boot(self, broker: Broker, worker: Worker):
        listen = CONFIG.get("listen.metrics", ["[::]:9300"])
        if isinstance(listen, str):
            listen = listen.split(",")
        addr, _, port = listen[0].rpartition(":")
    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
        from prometheus_client import values
        from prometheus_client.values import MultiProcessValue

        try:
            port = int(port)
        except ValueError:
            LOGGER.error(f"Invalid port entered: {port}")
        self.thread = HTTPServerThread(target=MetricsMiddleware.run, args=(addr, port))
        self.thread.start()
        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        server = self.thread.server
        if server:
            server.shutdown()
        LOGGER.debug("Stopping MetricsMiddleware")
        self.thread.join()
        return super().before_worker_boot(broker, worker)
@@ -1,4 +1,6 @@
-from django.utils.timezone import now, timedelta
+from datetime import timedelta
+
+from django.utils.timezone import now
 from django.utils.translation import gettext_lazy as _
 from dramatiq import actor
 
@@ -10,7 +10,6 @@ from dramatiq.results.middleware import Results
 from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
 
 from authentik.tasks.broker import PostgresBroker
-from authentik.tasks.middleware import WorkerHealthcheckMiddleware
 
 TESTING_QUEUE = "testing"
 
@@ -18,6 +17,7 @@ TESTING_QUEUE = "testing"
 class TestWorker(Worker):
     def __init__(self, broker: Broker):
         super().__init__(broker=broker)
+        self.worker_id = 1000
         self.work_queue = PriorityQueue()
         self.consumers = {
             TESTING_QUEUE: _ConsumerThread(
@@ -82,8 +82,6 @@ def use_test_broker():
     middleware: Middleware = import_string(middleware_class)(
         **middleware_kwargs,
     )
-    if isinstance(middleware, WorkerHealthcheckMiddleware):
-        middleware.port = 9102
     if isinstance(middleware, Retries):
         middleware.max_retries = 0
     if isinstance(middleware, Results):
86  lifecycle/ak
@@ -1,10 +1,6 @@
-#!/usr/bin/env -S bash
-set -e -o pipefail
-MODE_FILE="${TMPDIR}/authentik-mode"
+#!/usr/bin/env bash
 
-if [[ -z "${PROMETHEUS_MULTIPROC_DIR}" ]]; then
-    export PROMETHEUS_MULTIPROC_DIR="${TMPDIR:-/tmp}/authentik_prometheus_tmp"
-fi
+set -e -o pipefail
 
 function log {
     printf '{"event": "%s", "level": "info", "logger": "bootstrap"}\n' "$@" >&2
@@ -15,10 +11,18 @@
     log "Bootstrap completed"
 }
 
-function check_if_root {
+function run_authentik {
+    if [[ -x "$(command -v authentik)" ]]; then
+        echo authentik "$@"
+    else
+        echo cargo run -- "$@"
+    fi
+}
+
+function check_if_root_and_run {
     if [[ $EUID -ne 0 ]]; then
         log "Not running as root, disabling permission fixes"
-        exec $1
+        exec $(run_authentik "$@")
+        return
     fi
     SOCKET="/var/run/docker.sock"
@@ -26,36 +30,19 @@
     if [[ -e "$SOCKET" ]]; then
         # Get group ID of the docker socket, so we can create a matching group and
         # add ourselves to it
-        DOCKER_GID=$(stat -c '%g' $SOCKET)
+        DOCKER_GID="$(stat -c "%g" "${SOCKET}")"
         # Ensure group for the id exists
-        getent group $DOCKER_GID || groupadd -f -g $DOCKER_GID docker
-        usermod -a -G $DOCKER_GID authentik
+        getent group "${DOCKER_GID}" || groupadd -f -g "${DOCKER_GID}" docker
+        usermod -a -G "${DOCKER_GID}" authentik
         # since the name of the group might not be docker, we need to lookup the group id
-        GROUP_NAME=$(getent group $DOCKER_GID | sed 's/:/\n/g' | head -1)
+        GROUP_NAME=$(getent group "${DOCKER_GID}" | sed 's/:/\n/g' | head -1)
         GROUP="authentik:${GROUP_NAME}"
     fi
     # Fix permissions of certs and media
     chown -R authentik:authentik /data /certs "${PROMETHEUS_MULTIPROC_DIR}"
     chmod ug+rwx /data
     chmod ug+rx /certs
-    exec chpst -u authentik:$GROUP env HOME=/authentik $1
-}
-
-function run_authentik {
-    if [[ -x "$(command -v authentik)" ]]; then
-        exec authentik $@
-    else
-        exec go run -v ./cmd/server/ $@
-    fi
-}
-
-function set_mode {
-    echo $1 >$MODE_FILE
-    trap cleanup EXIT
-}
-
-function cleanup {
-    rm -f ${MODE_FILE}
+    exec chpst -u authentik:"${GROUP}" env HOME=/authentik $(run_authentik "$@")
 }
@@ -72,38 +59,31 @@ function prepare_debug {
     chown authentik:authentik /unittest.xml
 }
 
+if [[ -z "${PROMETHEUS_MULTIPROC_DIR}" ]]; then
+    export PROMETHEUS_MULTIPROC_DIR="${TMPDIR:-/tmp}/authentik_prometheus_tmp"
+fi
+mkdir -p "${PROMETHEUS_MULTIPROC_DIR}"
+
 if [[ "$(python -m authentik.lib.config debugger 2>/dev/null)" == "True" ]]; then
     prepare_debug
 fi
 
-if [[ "$1" == "server" ]]; then
-    set_mode "server"
-    run_authentik
-elif [[ "$1" == "worker" ]]; then
-    set_mode "worker"
-    shift
-    # If we have bootstrap credentials set, run bootstrap tasks outside of main server
-    # sync, so that we can sure the first start actually has working bootstrap
-    # credentials
-    if [[ -n "${AUTHENTIK_BOOTSTRAP_PASSWORD}" || -n "${AUTHENTIK_BOOTSTRAP_TOKEN}" ]]; then
-        python -m manage apply_blueprint system/bootstrap.yaml || true
-    fi
-    check_if_root "python -m manage worker --pid-file ${TMPDIR}/authentik-worker.pid $@"
-elif [[ "$1" == "bash" ]]; then
-    /bin/bash
-elif [[ "$1" == "test-all" ]]; then
-    prepare_debug
-    chmod 777 /root
-    check_if_root "python -m manage test authentik"
-elif [[ "$1" == "healthcheck" ]]; then
-    run_authentik healthcheck $(cat $MODE_FILE)
+if [[ "$1" == "bash" ]]; then
+    exec /usr/bin/env -S bash "$@"
 elif [[ "$1" == "dump_config" ]]; then
-    shift
-    exec python -m authentik.lib.config $@
+    shift 1
+    exec python -m authentik.lib.config "$@"
+elif [[ "$1" == "debug" ]]; then
+    exec sleep infinity
+elif [[ "$1" == "test-all" ]]; then
+    wait_for_db
+    prepare_debug
+    chmod 777 /root
+    check_if_root_and_run manage test authentik
+elif [[ "$1" == "allinone" ]] || [[ "$1" == "server" ]] || [[ "$1" == "worker" ]] || [[ "$1" == "proxy" ]] || [[ "$1" == "manage" ]]; then
+    wait_for_db
+    check_if_root_and_run "$@"
 else
+    wait_for_db
     exec python -m manage "$@"
 fi
@@ -1,6 +1,8 @@
 """Gunicorn config"""
 
 import os
+import platform
+import signal
 from hashlib import sha512
 from pathlib import Path
 from tempfile import gettempdir
@@ -17,7 +19,6 @@ from authentik.lib.utils.reflection import get_env
 from authentik.root.install_id import get_install_id_raw
 from authentik.root.setup import setup
 from lifecycle.migrate import run_migrations
-from lifecycle.wait_for_db import wait_for_db
 from lifecycle.worker import DjangoUvicornWorker
 
 if TYPE_CHECKING:
@@ -28,16 +29,12 @@ if TYPE_CHECKING:
 
 setup()
 
-wait_for_db()
-
 _tmp = Path(gettempdir())
 worker_class = "lifecycle.worker.DjangoUvicornWorker"
 worker_tmp_dir = str(_tmp.joinpath("authentik_gunicorn_tmp"))
 
 os.makedirs(worker_tmp_dir, exist_ok=True)
 
 bind = f"unix://{str(_tmp.joinpath('authentik-core.sock'))}"
 
 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
 
 preload_app = True
@@ -45,13 +42,28 @@ preload_app = True
 max_requests = CONFIG.get_int("web.max_requests", 1000)
 max_requests_jitter = CONFIG.get_int("web.max_requests_jitter", 50)
 
+# Match the value in src/arbiter.rs for graceful shutdown
+dirty_graceful_timeout = 30
+
 logconfig_dict = get_logger_config()
 
-default_workers = 2
-
-workers = CONFIG.get_int("web.workers", default_workers)
+workers = CONFIG.get_int("web.workers", 2)
 threads = CONFIG.get_int("web.threads", 4)
 
+# libpq can try Kerberos/GSS on macOS, which is not fork-safe in our Gunicorn worker model.
+# Disable GSS negotiation for local/dev PostgreSQL connections on Darwin.
+if platform.system() == "Darwin":
+    os.environ.setdefault("PGGSSENCMODE", "disable")
+    # Avoid macOS SystemConfiguration proxy lookups (_scproxy) in forked workers.
+    # urllib/requests may consult these APIs and can crash in child workers.
+    os.environ.setdefault("NO_PROXY", "*")
+    os.environ.setdefault("no_proxy", "*")
+
+
+def when_ready(server: "Arbiter"):  # noqa: UP037
+    # Notify rust process that we are ready
+    os.kill(os.getppid(), signal.SIGUSR1)
+
 
 def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):  # noqa: UP037
     """Tell prometheus to use worker number instead of process ID for multiprocess"""
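`when_ready` signals the parent process with SIGUSR1, and the arbiter (in `src/arbiter.rs` below) rebroadcasts USR1 to interested tasks. A hedged sketch of the consuming side, built only from the `Arbiter` API shown later in this diff; the task itself is illustrative:

```rust
// Illustrative consumer of the SIGUSR1 readiness contract above.
use tokio::signal::unix::SignalKind;

async fn watch_gunicorn_ready(arbiter: Arbiter) -> eyre::Result<()> {
    let mut signals = arbiter.signals_subscribe();
    loop {
        // when_ready() in gunicorn.conf.py kills getppid() with SIGUSR1.
        if signals.recv().await? == SignalKind::user_defined1() {
            arbiter.mark_gunicorn_ready();
            return Ok(());
        }
    }
}
```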
148  lifecycle/worker_process.py  (new file)
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
import os
import random
import signal
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from socket import AF_UNIX
from threading import Event, Thread
from typing import Any

from dramatiq import Worker, get_broker
from structlog.stdlib import get_logger

from authentik.lib.config import CONFIG

LOGGER = get_logger()
INITIAL_WORKER_ID = 1000


class HttpHandler(BaseHTTPRequestHandler):
    def check_db(self):
        from django.db import connections

        for db_conn in connections.all():
            # Force connection reload
            db_conn.connect()
            _ = db_conn.cursor()

    def do_GET(self):
        if self.path == "/-/metrics/":
            from authentik.root.monitoring import monitoring_set

            monitoring_set.send_robust(self)
            self.send_response(200)
            self.end_headers()
        elif self.path == "/-/health/ready/":
            from django.db.utils import OperationalError

            try:
                self.check_db()
            except OperationalError:
                self.send_response(503)
                self.end_headers()
                return
            self.send_response(200)
            self.end_headers()
        else:
            self.send_response(200)
            self.end_headers()

    def log_message(self, format: str, *args: Any) -> None:
        pass


class UnixSocketServer(HTTPServer):
    address_family = AF_UNIX


def main(worker_id: int, socket_path: str | None):
    shutdown = Event()
    srv = None

    def immediate_shutdown(signum, frame):
        nonlocal srv
        if srv is not None:
            srv.shutdown()
        if socket_path:
            os.remove(socket_path)
        sys.exit(0)

    def graceful_shutdown(signum, frame):
        nonlocal shutdown
        shutdown.set()

    signal.signal(signal.SIGHUP, immediate_shutdown)
    signal.signal(signal.SIGINT, immediate_shutdown)
    signal.signal(signal.SIGQUIT, immediate_shutdown)
    signal.signal(signal.SIGTERM, graceful_shutdown)

    random.seed()

    logger = LOGGER.bind(worker_id=worker_id)

    logger.debug("Loading broker...")
    broker = get_broker()
    broker.emit_after("process_boot")

    logger.debug("Starting worker threads...")
    queues = None  # all queues
    worker = Worker(broker, queues=queues, worker_threads=CONFIG.get_int("worker.threads"))
    worker.worker_id = worker_id
    worker.start()
    logger.info("Worker process is ready for action.")

    if socket_path:
        srv = UnixSocketServer(socket_path, HttpHandler)
        Thread(target=srv.serve_forever).start()

    # Notify rust process that we are ready
    os.kill(os.getppid(), signal.SIGUSR2)

    shutdown.wait()

    logger.info("Shutting down worker...")
    if srv is not None:
        srv.shutdown()
    if socket_path:
        os.remove(socket_path)
    # 5 secs if debug, 5 mins otherwise
    worker.stop(timeout=5_000 if CONFIG.get_bool("debug") else 600_000)
    broker.close()
    logger.info("Worker shut down.")


if __name__ == "__main__":
    if len(sys.argv) not in [2, 3]:
        print("USAGE: worker_process <worker_id> [SOCKET_PATH]")
        sys.exit(1)

    worker_id = int(sys.argv[1])
    socket_path = sys.argv[2] if len(sys.argv) == 3 else None  # noqa: PLR2004

    from authentik.root.setup import setup

    setup()

    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")

    import django

    django.setup()

    from django.core.management import execute_from_command_line

    if socket_path:
        from lifecycle.migrate import run_migrations

        run_migrations()

        if (
            "AUTHENTIK_BOOTSTRAP_PASSWORD" in os.environ
            or "AUTHENTIK_BOOTSTRAP_TOKEN" in os.environ
        ):
            try:
                execute_from_command_line(["", "apply_blueprint", "system/bootstrap.yaml"])
            except Exception as exc:  # noqa: BLE001
                sys.stderr.write(f"Failed to apply bootstrap blueprint: {exc}")

    main(worker_id, socket_path)
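The worker exposes `/-/metrics/` and `/-/health/ready/` on a unix socket, so the supervising process can probe it without opening a TCP port. A hedged sketch of such a probe over a raw `UnixStream` (the diff pulls in `hyper-unix-socket` for the real thing; this minimal version avoids assuming that crate's API):

```rust
// Illustrative readiness probe against the unix-socket server above.
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;

async fn worker_ready(socket_path: &str) -> std::io::Result<bool> {
    let mut stream = UnixStream::connect(socket_path).await?;
    // Minimal HTTP/1.0 request; the handler only inspects the path.
    stream
        .write_all(b"GET /-/health/ready/ HTTP/1.0\r\nHost: localhost\r\n\r\n")
        .await?;
    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    // BaseHTTPRequestHandler defaults to HTTP/1.0 status lines.
    Ok(response.starts_with(b"HTTP/1.0 200"))
}
```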
@@ -8,12 +8,10 @@ from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
 
 from authentik.root.setup import setup
 from lifecycle.migrate import run_migrations
-from lifecycle.wait_for_db import wait_for_db
 
 setup()
 
 if __name__ == "__main__":
-    wait_for_db()
     if (
         len(sys.argv) > 1
         # Explicitly only run migrate for server and worker
291  src/arbiter.rs  (new file)
@@ -0,0 +1,291 @@
//! Utilities to manage long running tasks, such as servers and watchers.
//!
//! Also manages signals sent to the main process.

use std::{net, os::unix, sync::Arc, time::Duration};

use axum_server::Handle;
use eyre::{Report, Result};
use tokio::{
    signal::unix::{Signal, SignalKind, signal},
    sync::{Mutex, broadcast, watch},
    task::{JoinSet, join_set::Builder},
};
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::info;

/// All the signal streams we watch for. We don't create those directly in [`watch_signals`]
/// because that would prevent us from handling errors early.
struct SignalStreams {
    hup: Signal,
    int: Signal,
    quit: Signal,
    usr1: Signal,
    usr2: Signal,
    term: Signal,
}

impl SignalStreams {
    fn new() -> Result<Self> {
        Ok(Self {
            hup: signal(SignalKind::hangup())?,
            int: signal(SignalKind::interrupt())?,
            quit: signal(SignalKind::quit())?,
            usr1: signal(SignalKind::user_defined1())?,
            usr2: signal(SignalKind::user_defined2())?,
            term: signal(SignalKind::terminate())?,
        })
    }
}

/// Watch for incoming signals and either shutdown the application or dispatch them to receivers.
async fn watch_signals(
    streams: SignalStreams,
    arbiter: Arbiter,
    _signals_rx: broadcast::Receiver<SignalKind>,
) -> Result<()> {
    info!("starting signals watcher");
    let SignalStreams {
        mut hup,
        mut int,
        mut quit,
        mut usr1,
        mut usr2,
        mut term,
    } = streams;
    loop {
        tokio::select! {
            _ = hup.recv() => {
                info!("signal HUP received");
                arbiter.do_fast_shutdown().await;
            },
            _ = int.recv() => {
                info!("signal INT received");
                arbiter.do_fast_shutdown().await;
            },
            _ = quit.recv() => {
                info!("signal QUIT received");
                arbiter.do_fast_shutdown().await;
            },
            _ = usr1.recv() => {
                info!("signal USR1 received");
                arbiter.signals_tx.send(SignalKind::user_defined1())?;
            },
            _ = usr2.recv() => {
                info!("signal USR2 received");
                arbiter.signals_tx.send(SignalKind::user_defined2())?;
            },
            _ = term.recv() => {
                info!("signal TERM received");
                arbiter.do_graceful_shutdown().await;
            },
            () = arbiter.shutdown() => {
                info!("stopping signals watcher");
                return Ok(());
            }
        };
    }
}

/// Manager for long running tasks, such as servers and watchers.
pub(crate) struct Tasks {
    pub(crate) tasks: JoinSet<Result<()>>,
    arbiter: Arbiter,
}

impl Tasks {
    pub(crate) fn new() -> Result<Self> {
        let mut tasks = JoinSet::new();
        let arbiter = Arbiter::new(&mut tasks)?;

        Ok(Self { tasks, arbiter })
    }

    /// Build a new task. See [`tokio::task::JoinSet::build_task`] for details.
    pub(crate) fn build_task(&mut self) -> Builder<'_, Result<()>> {
        self.tasks.build_task()
    }

    /// Get an [`Arbiter`]
    pub(crate) fn arbiter(&self) -> Arbiter {
        self.arbiter.clone()
    }

    pub(crate) async fn run(self) -> Vec<Report> {
        let Self { mut tasks, arbiter } = self;

        let mut errors = Vec::new();

        if let Some(result) = tasks.join_next().await {
            arbiter.do_graceful_shutdown().await;

            match result {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    arbiter.do_fast_shutdown().await;
                    errors.push(err);
                }
                Err(err) => {
                    arbiter.do_fast_shutdown().await;
                    errors.push(Report::new(err));
                }
            }

            while let Some(result) = tasks.join_next().await {
                match result {
                    Ok(Ok(())) => {}
                    Ok(Err(err)) => errors.push(err),
                    Err(err) => errors.push(Report::new(err)),
                }
            }
        }

        errors
    }
}

/// Manage shutdown state and several communication channels.
#[derive(Clone)]
pub(crate) struct Arbiter {
    /// Token to shutdown the application immediately.
    fast_shutdown: CancellationToken,
    /// Token to shutdown the application gracefully.
    graceful_shutdown: CancellationToken,
    /// Token set when any shutdown is triggered.
    shutdown: CancellationToken,

    /// Axum handles to manage
    net_handles: Arc<Mutex<Vec<Handle<net::SocketAddr>>>>,
    unix_handles: Arc<Mutex<Vec<Handle<unix::net::SocketAddr>>>>,

    /// Broadcaster of signals sent to the main process.
    signals_tx: broadcast::Sender<SignalKind>,
    /// Watcher of config change events
    config_changed_tx: watch::Sender<()>,
    _config_changed_rx: watch::Receiver<()>,

    /// Token set when gunicorn is marked ready
    gunicorn_ready: CancellationToken,
}

impl Arbiter {
    fn new(tasks: &mut JoinSet<Result<()>>) -> Result<Self> {
        let (signals_tx, signals_rx) = broadcast::channel(10);
        let (config_changed_tx, config_changed_rx) = watch::channel(());
        let arbiter = Self {
            fast_shutdown: CancellationToken::new(),
            graceful_shutdown: CancellationToken::new(),
            shutdown: CancellationToken::new(),

            // 5 is http, https, metrics and a bit of room
            net_handles: Arc::new(Mutex::new(Vec::with_capacity(5))),
            // 2 is http and metrics
            unix_handles: Arc::new(Mutex::new(Vec::with_capacity(2))),

            signals_tx,
            config_changed_tx,
            _config_changed_rx: config_changed_rx,

            gunicorn_ready: CancellationToken::new(),
        };

        let streams = SignalStreams::new()?;

        tasks
            .build_task()
            .name(&format!("{}::watch_signals", module_path!()))
            .spawn(watch_signals(streams, arbiter.clone(), signals_rx))?;

        Ok(arbiter)
    }

    pub(crate) async fn add_net_handle(&self, handle: Handle<net::SocketAddr>) {
        self.net_handles.lock().await.push(handle);
    }

    pub(crate) async fn add_unix_handle(&self, handle: Handle<unix::net::SocketAddr>) {
        self.unix_handles.lock().await.push(handle);
    }

    /// Future that will complete when the application needs to shutdown immediately.
    pub(crate) fn fast_shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.fast_shutdown.cancelled()
    }

    /// Future that will complete when the application needs to shutdown gracefully.
    pub(crate) fn graceful_shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.graceful_shutdown.cancelled()
    }

    /// Future that will complete when the application needs to shutdown, either immediately or
    /// gracefully. It's a helper so users that don't make the difference between immediate and
    /// graceful shutdown don't need to handle two scenarios.
    pub(crate) fn shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.shutdown.cancelled()
    }

    /// Shutdown the application immediately.
    async fn do_fast_shutdown(&self) {
        info!("arbiter has been told to shutdown immediately");
        self.unix_handles
            .lock()
            .await
            .iter()
            .for_each(Handle::shutdown);
        self.net_handles
            .lock()
            .await
            .iter()
            .for_each(Handle::shutdown);
        info!("all webservers have been shutdown, shutting down the other tasks immediately");
        self.fast_shutdown.cancel();
        self.shutdown.cancel();
    }

    /// Shutdown the application gracefully.
    async fn do_graceful_shutdown(&self) {
        info!("arbiter has been told to shutdown gracefully");
        // Match the value in lifecycle/gunicorn.conf.py for graceful shutdown
        let timeout = Some(Duration::from_secs(30 + 5));
        self.unix_handles
            .lock()
            .await
            .iter()
            .for_each(|handle| handle.graceful_shutdown(timeout));
        self.net_handles
            .lock()
            .await
            .iter()
            .for_each(|handle| handle.graceful_shutdown(timeout));
        info!("all webservers have been shutdown, shutting down the other tasks gracefully");
        self.graceful_shutdown.cancel();
        self.shutdown.cancel();
    }

    /// Create a new [`broadcast::Receiver`] to listen for signals sent to the main process. This
    /// may not include all signals we catch, since some of those will shutdown the application.
    pub(crate) fn signals_subscribe(&self) -> broadcast::Receiver<SignalKind> {
        self.signals_tx.subscribe()
    }

    /// Send a value on the config changes watch channel
    pub(crate) fn config_changed_send(&self, value: ()) -> Result<()> {
        self.config_changed_tx.send(value)?;
        Ok(())
    }

    /// Create a new [`watch::Receiver`] to listen for detected configuration changes.
    pub(crate) fn config_changed_subscribe(&self) -> watch::Receiver<()> {
        self.config_changed_tx.subscribe()
    }

    /// Future that will complete when gunicorn has been marked ready.
    pub(crate) fn gunicorn_ready(&self) -> WaitForCancellationFuture<'_> {
        self.gunicorn_ready.cancelled()
    }

    /// Mark gunicorn as ready
    pub(crate) fn mark_gunicorn_ready(&self) {
        self.gunicorn_ready.cancel();
    }
}
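How a long-running task cooperates with the `Arbiter` above, as a hedged sketch; `watch_config` is illustrative, but the methods are the ones defined in this file:

```rust
// Illustrative task using the Arbiter API defined above.
async fn watch_config(arbiter: Arbiter) -> eyre::Result<()> {
    let mut changes = arbiter.config_changed_subscribe();
    loop {
        tokio::select! {
            changed = changes.changed() => {
                changed?; // sender dropped: treat as shutdown
                tracing::info!("configuration changed, reloading");
            },
            () = arbiter.shutdown() => return Ok(()),
        }
    }
}
```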
2  src/axum/accept/mod.rs  (new file)
@@ -0,0 +1,2 @@
+pub(crate) mod proxy_protocol;
+pub(crate) mod tls;
86  src/axum/accept/proxy_protocol.rs  (new file)
@@ -0,0 +1,86 @@
|
||||
use std::{io, time::Duration};
|
||||
|
||||
use axum::{Extension, middleware::AddExtension};
|
||||
use axum_server::accept::{Accept, DefaultAcceptor};
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tower::Layer as _;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ProxyProtocolState {
|
||||
pub(crate) header: Option<Header<'static>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ProxyProtocolAcceptor<A = DefaultAcceptor> {
|
||||
inner: A,
|
||||
parsing_timeout: Duration,
|
||||
}
|
||||
|
||||
impl ProxyProtocolAcceptor {
|
||||
pub(crate) fn new() -> Self {
|
||||
let inner = DefaultAcceptor::new();
|
||||
|
||||
#[cfg(not(test))]
|
||||
let parsing_timeout = Duration::from_secs(10);
|
||||
|
||||
// Don't force tests to wait too long
|
||||
#[cfg(test)]
|
||||
let parsing_timeout = Duration::from_secs(1);
|
||||
|
||||
Self {
|
||||
inner,
|
||||
parsing_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProxyProtocolAcceptor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<A> ProxyProtocolAcceptor<A> {
|
||||
pub(crate) fn acceptor<Acceptor>(self, acceptor: Acceptor) -> ProxyProtocolAcceptor<Acceptor> {
|
||||
ProxyProtocolAcceptor {
|
||||
inner: acceptor,
|
||||
parsing_timeout: self.parsing_timeout,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<A, I, S> Accept<I, S> for ProxyProtocolAcceptor<A>
|
||||
where
|
||||
A: Accept<I, S> + Clone + Send + 'static,
|
||||
A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
|
||||
A::Service: Send,
|
||||
A::Future: Send,
|
||||
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
S: Send + 'static,
|
||||
{
|
||||
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
|
||||
type Service = AddExtension<A::Service, ProxyProtocolState>;
|
||||
type Stream = ProxyProtocolStream<A::Stream>;
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn accept(&self, stream: I, service: S) -> Self::Future {
|
||||
let acceptor = self.inner.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let (stream, service) = acceptor.accept(stream, service).await?;
|
||||
let stream = ProxyProtocolStream::new(stream).await?;
|
||||
|
||||
let proxy_protocol_state = ProxyProtocolState {
|
||||
header: stream.header().cloned(),
|
||||
};
|
||||
|
||||
let service = Extension(proxy_protocol_state).layer(service);
|
||||
|
||||
Ok((stream, service))
|
||||
})
|
||||
}
|
||||
}
|
||||
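As context for reviewers, a minimal sketch of how the extension injected by this acceptor could be consumed downstream; the handler and route names here are illustrative, not part of this change:

```rust
use axum::{Extension, Router, routing::get};

use crate::axum::accept::proxy_protocol::ProxyProtocolState;

// Hypothetical handler: the acceptor layers ProxyProtocolState onto the
// service, so any handler behind it can extract the parsed header again.
async fn proxy_info(Extension(state): Extension<ProxyProtocolState>) -> String {
    format!("PROXY protocol header present: {}", state.header.is_some())
}

fn demo_router() -> Router {
    Router::new().route("/proxy-info", get(proxy_info))
}
```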
src/axum/accept/tls.rs (Normal file, 54 lines)
@@ -0,0 +1,54 @@
use axum::{Extension, middleware::AddExtension};
use axum_server::{accept::Accept, tls_rustls::RustlsAcceptor};
use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{rustls::pki_types::CertificateDer, server::TlsStream};
use tower::Layer as _;
use tracing::instrument;

#[derive(Clone, Debug)]
pub(crate) struct TlsState {
    pub(crate) peer_certificates: Option<Vec<CertificateDer<'static>>>,
}

#[derive(Clone)]
pub(crate) struct TlsAcceptor<A> {
    inner: RustlsAcceptor<A>,
}

impl<A> TlsAcceptor<A> {
    pub(crate) fn new(inner: RustlsAcceptor<A>) -> Self {
        Self { inner }
    }
}

impl<A, I, S> Accept<I, S> for TlsAcceptor<A>
where
    A: Accept<I, S> + Clone + Send + 'static,
    A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
    A::Service: Send,
    A::Future: Send,
    I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    S: Send + 'static,
{
    type Future = BoxFuture<'static, std::io::Result<(Self::Stream, Self::Service)>>;
    type Service = AddExtension<A::Service, TlsState>;
    type Stream = TlsStream<A::Stream>;

    #[instrument(skip_all)]
    fn accept(&self, stream: I, service: S) -> Self::Future {
        let acceptor = self.inner.clone();

        Box::pin(async move {
            let (stream, service) = acceptor.accept(stream, service).await?;
            let server_conn = stream.get_ref().1;
            let tls_state = TlsState {
                peer_certificates: server_conn.peer_certificates().map(|c| c.to_owned()),
            };

            let service = Extension(tls_state).layer(service);

            Ok((stream, service))
        })
    }
}
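A hedged sketch of one way the captured peer certificates could be used, e.g. to gate a route on mTLS; the handler name is made up for illustration:

```rust
use axum::{Extension, http::StatusCode};

use crate::axum::accept::tls::TlsState;

// Hypothetical handler: reject the request unless the TLS handshake
// presented at least one client certificate.
async fn require_client_cert(
    Extension(tls): Extension<TlsState>,
) -> Result<&'static str, StatusCode> {
    match tls.peer_certificates {
        Some(certs) if !certs.is_empty() => Ok("client certificate present"),
        _ => Err(StatusCode::UNAUTHORIZED),
    }
}
```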
src/axum/error.rs (Normal file, 26 lines)
@@ -0,0 +1,26 @@
use axum::{
    http::StatusCode,
    response::{IntoResponse, Response},
};
use eyre::Report;
use tracing::warn;

#[derive(Debug)]
pub(crate) struct AppError(pub(crate) Report);

impl<E> From<E> for AppError
where E: Into<Report>
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        warn!("error occurred: {:?}", self.0);
        (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
    }
}

pub(crate) type Result<T, E = AppError> = core::result::Result<T, E>;
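With this alias, fallible handlers can use `?` on anything convertible into `eyre::Report`, and failures render as a generic 500. A minimal usage sketch (the handler and file path are hypothetical):

```rust
use crate::axum::error::Result;

// Hypothetical handler: the io::Error from read_to_string converts into
// AppError via the blanket From impl above and becomes a 500 response.
async fn read_motd() -> Result<String> {
    let motd = tokio::fs::read_to_string("/etc/motd").await?;
    Ok(motd)
}
```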
src/axum/extract/client_ip.rs (Normal file, 228 lines)
@@ -0,0 +1,228 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};

use axum::{
    Extension, RequestPartsExt as _,
    extract::{ConnectInfo, FromRequestParts, Request},
    http::request::Parts,
    middleware::Next,
    response::Response,
};
use tracing::{Span, instrument};

use crate::axum::{
    accept::proxy_protocol::ProxyProtocolState, extract::trusted_proxy::TrustedProxy,
};

#[derive(Clone, Copy, Debug)]
pub(crate) struct ClientIp(pub IpAddr);

impl<S> FromRequestParts<S> for ClientIp
where S: Send + Sync
{
    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        Extension::<Self>::from_request_parts(parts, state)
            .await
            .map(|Extension(client_ip)| client_ip)
    }
}

#[instrument(skip_all)]
async fn extract_client_ip(parts: &mut Parts) -> IpAddr {
    let is_trusted = parts
        .extract::<TrustedProxy>()
        .await
        .unwrap_or(TrustedProxy(false))
        .0;

    if is_trusted {
        if let Ok(ip) = client_ip::rightmost_x_forwarded_for(&parts.headers) {
            return ip;
        }

        if let Ok(ip) = client_ip::x_real_ip(&parts.headers) {
            return ip;
        }

        if let Ok(ip) = client_ip::rightmost_forwarded(&parts.headers) {
            return ip;
        }

        if let Ok(Extension(proxy_protocol_state)) =
            parts.extract::<Extension<ProxyProtocolState>>().await
            && let Some(header) = &proxy_protocol_state.header
            && let Some(addr) = header.proxied_address()
        {
            return addr.source.ip();
        }
    }

    if let Ok(ConnectInfo(addr)) = parts.extract::<ConnectInfo<SocketAddr>>().await {
        addr.ip()
    } else {
        // No connect info means we received a request via a Unix socket, hence localhost
        // as default
        Ipv6Addr::LOCALHOST.into()
    }
}

pub(crate) async fn client_ip_middleware(request: Request, next: Next) -> Response {
    let (mut parts, body) = request.into_parts();

    let client_ip = extract_client_ip(&mut parts).await;
    Span::current().record("remote", client_ip.to_string());
    parts.extensions.insert::<ClientIp>(ClientIp(client_ip));

    let request = Request::from_parts(parts, body);

    next.run(request).await
}

#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    use axum::{body::Body, http::Request};

    use super::*;

    #[tokio::test]
    async fn x_forwarded_for_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-for", "192.0.2.51, 192.0.2.42")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42));
    }

    #[tokio::test]
    async fn x_real_ip_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-real-ip", "192.0.2.42")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42));
    }

    #[tokio::test]
    async fn forwarded_header_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("forwarded", "for=192.0.2.42")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42));
    }

    #[tokio::test]
    async fn from_connect_info() {
        let connect_addr: SocketAddr = "192.0.2.42:34932"
            .parse()
            .expect("Failed to parse socket address");
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .extension(ConnectInfo(connect_addr))
            .extension(TrustedProxy(false))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42));
    }

    #[tokio::test]
    async fn headers_untrusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-for", "192.0.2.42")
            .extension(TrustedProxy(false))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
    }

    #[tokio::test]
    async fn priority_order() {
        // Test that X-Forwarded-For takes priority over other headers when trusted
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-for", "192.0.2.1")
            .header("x-real-ip", "192.0.2.2")
            .header("forwarded", "for=192.0.2.3")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 1));
    }

    #[tokio::test]
    async fn no_ip_found() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
    }

    #[tokio::test]
    async fn ipv6() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-for", "2001:db8::42")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42));
    }

    #[tokio::test]
    async fn multiple_x_forwarded_for() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-for", "192.0.2.1, 192.0.2.2, 192.0.2.3")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let client_ip = extract_client_ip(&mut parts).await;

        assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 3));
    }
}
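Because `ClientIp` implements `FromRequestParts`, it can be used directly as an extractor once `client_ip_middleware` has run. A minimal sketch (handler and route are illustrative):

```rust
use axum::{Router, routing::get};

use crate::axum::extract::client_ip::ClientIp;

// Hypothetical handler: the middleware stored ClientIp as an extension, so
// the extractor impl above simply pulls it back out.
async fn whoami(ClientIp(ip): ClientIp) -> String {
    format!("your ip: {ip}")
}

fn demo_router() -> Router {
    Router::new().route("/whoami", get(whoami))
}
```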
src/axum/extract/host.rs (Normal file, 262 lines)
@@ -0,0 +1,262 @@
use axum::{
    Extension, RequestPartsExt as _,
    extract::{FromRequestParts, Request},
    http::{
        header::{FORWARDED, HOST},
        request::Parts,
        status::StatusCode,
    },
    middleware::Next,
    response::{IntoResponse as _, Response},
};
use forwarded_header_value::ForwardedHeaderValue;
use tracing::{Span, instrument};

use crate::axum::extract::trusted_proxy::TrustedProxy;

const X_FORWARDED_HOST: &str = "X-Forwarded-Host";

#[derive(Clone, Debug)]
pub(crate) struct Host(pub String);

impl<S> FromRequestParts<S> for Host
where S: Send + Sync
{
    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        Extension::<Self>::from_request_parts(parts, state)
            .await
            .map(|Extension(host)| host)
    }
}

#[instrument(skip_all)]
async fn extract_host(parts: &mut Parts) -> Result<String, (StatusCode, &'static str)> {
    let is_trusted = parts
        .extract::<TrustedProxy>()
        .await
        .unwrap_or(TrustedProxy(false))
        .0;

    if is_trusted {
        if let Some(host) = parts
            .headers
            .get(X_FORWARDED_HOST)
            .and_then(|host| host.to_str().ok())
        {
            return Ok(host.to_owned());
        }

        if let Some(forwarded) = parts.headers.get(FORWARDED)
            && let Ok(forwarded) = forwarded.to_str()
            && let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
        {
            for stanza in forwarded.iter() {
                if let Some(forwarded_host) = &stanza.forwarded_host {
                    return Ok(forwarded_host.to_owned());
                }
            }
        }
    }

    if let Some(host) = parts.headers.get(HOST).and_then(|host| host.to_str().ok()) {
        return Ok(host.to_owned());
    }

    if let Some(host) = parts.uri.host() {
        Ok(host.to_owned())
    } else {
        Err((StatusCode::BAD_REQUEST, "missing host header"))
    }
}

pub(crate) async fn host_middleware(request: Request, next: Next) -> Response {
    let (mut parts, body) = request.into_parts();

    let host = match extract_host(&mut parts).await {
        Ok(host) => host,
        Err(err) => return err.into_response(),
    };
    Span::current().record("host", host.clone());
    parts.extensions.insert::<Host>(Host(host));

    let request = Request::from_parts(parts, body);

    next.run(request).await
}

#[cfg(test)]
mod tests {
    use axum::{body::Body, http::Request};

    use super::*;

    #[tokio::test]
    async fn host_header() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("host", "example.com:8080")
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "example.com:8080",
        );
    }

    #[tokio::test]
    async fn from_uri() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com:8080/path")
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "example.com",
        );
    }

    #[tokio::test]
    async fn x_forwarded_host_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-host", "forwarded.example.com")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "forwarded.example.com",
        );
    }

    #[tokio::test]
    async fn forwarded_header_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("forwarded", "host=forwarded.example.com")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "forwarded.example.com",
        );
    }

    #[tokio::test]
    async fn forwarded_host_untrusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-host", "malicious.example.com")
            .extension(TrustedProxy(false))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "example.com",
        );
    }

    #[tokio::test]
    async fn forwarded_header_untrusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("forwarded", "host=malicious.example.com")
            .extension(TrustedProxy(false))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "example.com",
        );
    }

    #[tokio::test]
    async fn priority_order() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-host", "x-forwarded.example.com")
            .header("forwarded", "host=forwarded.example.com")
            .header("host", "host-header.example.com")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "x-forwarded.example.com",
        );
    }

    #[tokio::test]
    async fn no_host_found() {
        let (mut parts, _) = Request::builder()
            .uri("/path")
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_err());
        assert_eq!(result.expect_err("Host extraction should fail").0, 400);
    }

    #[tokio::test]
    async fn multiple_forwarded_stanzas() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header(
                "forwarded",
                "host=first.example.com, host=second.example.com",
            )
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let result = extract_host(&mut parts).await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Host extraction should succeed"),
            "first.example.com",
        );
    }
}
src/axum/extract/mod.rs (Normal file, 4 lines)
@@ -0,0 +1,4 @@
pub(crate) mod client_ip;
pub(crate) mod host;
pub(crate) mod scheme;
pub(crate) mod trusted_proxy;
src/axum/extract/scheme.rs (Normal file, 241 lines)
@@ -0,0 +1,241 @@
use axum::{
    Extension, RequestPartsExt as _,
    extract::{FromRequestParts, Request},
    http::{self, header::FORWARDED, request::Parts},
    middleware::Next,
    response::Response,
};
use forwarded_header_value::{ForwardedHeaderValue, Protocol};
use tracing::{Span, instrument};

use crate::axum::{
    accept::{proxy_protocol::ProxyProtocolState, tls::TlsState},
    extract::trusted_proxy::TrustedProxy,
};

const X_FORWARDED_PROTO: &str = "X-Forwarded-Proto";
const X_FORWARDED_SCHEME: &str = "X-Forwarded-Scheme";

#[derive(Clone, Debug)]
pub(crate) struct Scheme(pub http::uri::Scheme);

impl<S> FromRequestParts<S> for Scheme
where S: Send + Sync
{
    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        Extension::<Self>::from_request_parts(parts, state)
            .await
            .map(|Extension(scheme)| scheme)
    }
}

#[instrument(skip_all)]
async fn extract_scheme(parts: &mut Parts) -> http::uri::Scheme {
    let is_trusted = parts
        .extract::<TrustedProxy>()
        .await
        .unwrap_or(TrustedProxy(false))
        .0;

    if is_trusted {
        if let Some(proto) = parts.headers.get(X_FORWARDED_PROTO)
            && let Ok(proto) = proto.to_str()
            && let Ok(scheme) = proto.to_lowercase().as_str().try_into()
        {
            return scheme;
        }

        if let Some(proto) = parts.headers.get(X_FORWARDED_SCHEME)
            && let Ok(proto) = proto.to_str()
            && let Ok(scheme) = proto.to_lowercase().as_str().try_into()
        {
            return scheme;
        }

        if let Some(forwarded) = parts.headers.get(FORWARDED)
            && let Ok(forwarded) = forwarded.to_str()
            && let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
        {
            for stanza in forwarded.iter() {
                if let Some(forwarded_proto) = &stanza.forwarded_proto {
                    let scheme = match forwarded_proto {
                        Protocol::Http => http::uri::Scheme::HTTP,
                        Protocol::Https => http::uri::Scheme::HTTPS,
                    };
                    return scheme;
                }
            }
        }

        if let Ok(Extension(proxy_protocol_state)) =
            parts.extract::<Extension<ProxyProtocolState>>().await
            && let Some(header) = &proxy_protocol_state.header
            && let Some(_) = header.ssl()
        {
            return http::uri::Scheme::HTTPS;
        }
    }

    if parts.extract::<Extension<TlsState>>().await.is_ok() {
        http::uri::Scheme::HTTPS
    } else {
        http::uri::Scheme::HTTP
    }
}

pub(crate) async fn scheme_middleware(request: Request, next: Next) -> Response {
    let (mut parts, body) = request.into_parts();

    let scheme = extract_scheme(&mut parts).await;
    Span::current().record("scheme", scheme.to_string());
    parts.extensions.insert::<Scheme>(Scheme(scheme));

    let request = Request::from_parts(parts, body);

    next.run(request).await
}

#[cfg(test)]
mod tests {
    use axum::{body::Body, http::Request};

    use super::*;

    #[tokio::test]
    async fn x_forwarded_proto_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-proto", "https")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTPS);
    }

    #[tokio::test]
    async fn x_forwarded_scheme_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-scheme", "https")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTPS);
    }

    #[tokio::test]
    async fn forwarded_header_trusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("forwarded", "proto=https")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTPS);
    }

    #[tokio::test]
    async fn x_forwarded_proto_untrusted() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-proto", "https")
            .extension(TrustedProxy(false))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTP);
    }

    #[tokio::test]
    async fn scheme_from_tls_state() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .extension(TlsState {
                peer_certificates: None,
            })
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTPS);
    }

    #[tokio::test]
    async fn scheme_defaults_to_http() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTP);
    }

    #[tokio::test]
    async fn priority_order() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-proto", "http")
            .header("x-forwarded-scheme", "https")
            .header("forwarded", "proto=https")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTP);
    }

    #[tokio::test]
    async fn multiple_forwarded_stanzas() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("forwarded", "proto=http, proto=https")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTP);
    }

    #[tokio::test]
    async fn test_scheme_case_insensitive() {
        let (mut parts, _) = Request::builder()
            .uri("http://example.com/path")
            .header("x-forwarded-proto", "HTTPS")
            .extension(TrustedProxy(true))
            .body(Body::empty())
            .expect("Failed to create request")
            .into_parts();

        let scheme = extract_scheme(&mut parts).await;

        assert_eq!(scheme, http::uri::Scheme::HTTPS);
    }
}
src/axum/extract/trusted_proxy.rs (Normal file, 59 lines)
@@ -0,0 +1,59 @@
use std::net::SocketAddr;

use axum::{
    Extension, RequestPartsExt as _,
    extract::{ConnectInfo, FromRequestParts, Request},
    http::request::Parts,
    middleware::Next,
    response::Response,
};
use tracing::{instrument, trace};

use crate::config;

#[derive(Clone, Copy, Debug)]
pub(crate) struct TrustedProxy(pub bool);

impl<S> FromRequestParts<S> for TrustedProxy
where S: Send + Sync
{
    type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        Extension::<Self>::from_request_parts(parts, state)
            .await
            .map(|Extension(trusted_proxy)| trusted_proxy)
    }
}

#[instrument(skip_all)]
async fn extract_trusted_proxy(parts: &mut Parts) -> bool {
    if let Ok(ConnectInfo(addr)) = parts.extract::<ConnectInfo<SocketAddr>>().await {
        let trusted_proxy_cidrs = &config::get().listen.trusted_proxy_cidrs;

        for trusted_net in trusted_proxy_cidrs {
            if trusted_net.contains(&addr.ip()) {
                trace!(
                    ?addr,
                    ?trusted_net,
                    "connection is now considered coming from a trusted proxy"
                );
                return true;
            }
        }
    }
    false
}

pub(crate) async fn trusted_proxy_middleware(request: Request, next: Next) -> Response {
    let (mut parts, body) = request.into_parts();

    let trusted_proxy = extract_trusted_proxy(&mut parts).await;
    parts
        .extensions
        .insert::<TrustedProxy>(TrustedProxy(trusted_proxy));

    let request = Request::from_parts(parts, body);

    next.run(request).await
}
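The containment check above is plain `ipnet` CIDR matching; a standalone sketch with made-up values, showing the same logic outside the middleware:

```rust
use std::net::IpAddr;

use ipnet::IpNet;

// A connection counts as "trusted" when its peer address falls inside any
// CIDR from listen.trusted_proxy_cidrs (values below are illustrative).
fn is_trusted(peer: IpAddr, cidrs: &[IpNet]) -> bool {
    cidrs.iter().any(|net| net.contains(&peer))
}

fn main() {
    let cidrs: Vec<IpNet> = vec!["10.0.0.0/8".parse().expect("valid CIDR")];
    assert!(is_trusted("10.1.2.3".parse().expect("valid IP"), &cidrs));
    assert!(!is_trusted("192.0.2.1".parse().expect("valid IP"), &cidrs));
}
```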
src/axum/mod.rs (Normal file, 6 lines)
@@ -0,0 +1,6 @@
pub(crate) mod accept;
pub(crate) mod error;
pub(crate) mod extract;
pub(crate) mod router;
pub(crate) mod server;
pub(crate) mod trace;
src/axum/router.rs (Normal file, 39 lines)
@@ -0,0 +1,39 @@
use axum::{Router, http::status::StatusCode, middleware::from_fn};
use tower::ServiceBuilder;
use tower_http::timeout::TimeoutLayer;

use crate::{
    axum::{
        extract::{
            client_ip::client_ip_middleware, host::host_middleware, scheme::scheme_middleware,
            trusted_proxy::trusted_proxy_middleware,
        },
        trace::{span_middleware, tracing_middleware},
    },
    config,
};

#[inline]
pub(crate) fn wrap_router(router: Router, with_trace: bool) -> Router {
    let config = config::get();
    let timeout = durstr::parse(&config.web.timeout_http_read_header)
        .expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_read).expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_write).expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_idle).expect("Invalid duration in http timeout");
    let service_builder = ServiceBuilder::new()
        .layer(TimeoutLayer::with_status_code(
            StatusCode::REQUEST_TIMEOUT,
            timeout,
        ))
        .layer(from_fn(span_middleware))
        .layer(from_fn(trusted_proxy_middleware))
        .layer(from_fn(client_ip_middleware))
        .layer(from_fn(scheme_middleware))
        .layer(from_fn(host_middleware));
    if with_trace {
        router.layer(service_builder.layer(from_fn(tracing_middleware)))
    } else {
        router.layer(service_builder)
    }
}
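A minimal composition sketch showing how an application router would be passed through `wrap_router` to pick up the timeout plus the span/proxy/client-ip/scheme/host middleware stack (route and handler are illustrative):

```rust
use axum::{Router, routing::get};

use crate::axum::router::wrap_router;

// Hypothetical app setup: build routes first, then wrap once at the edge.
fn build_app() -> Router {
    let app = Router::new().route("/healthz", get(|| async { "ok" }));
    wrap_router(app, true)
}
```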
src/axum/server.rs (Normal file, 119 lines)
@@ -0,0 +1,119 @@
use std::{net, os::unix};

use axum::Router;
use axum_server::{
    Handle,
    accept::DefaultAcceptor,
    tls_rustls::{RustlsAcceptor, RustlsConfig},
};
use eyre::Result;
use tracing::info;

use crate::{
    arbiter::{Arbiter, Tasks},
    axum::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor},
};

async fn run_plain(
    arbiter: Arbiter,
    name: &str,
    router: Router,
    addr: net::SocketAddr,
) -> Result<()> {
    info!(addr = addr.to_string(), "starting {name} server");

    let handle = Handle::new();
    arbiter.add_net_handle(handle.clone()).await;

    axum_server::Server::bind(addr)
        .handle(handle)
        .serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
        .await?;

    Ok(())
}

pub(crate) fn start_plain(
    tasks: &mut Tasks,
    name: &'static str,
    router: Router,
    addr: net::SocketAddr,
) -> Result<()> {
    let arbiter = tasks.arbiter();
    tasks
        .build_task()
        .name(&format!("{}::run_plain({name}, {addr})", module_path!()))
        .spawn(run_plain(arbiter, name, router, addr))?;
    Ok(())
}

pub(crate) async fn run_unix(
    arbiter: Arbiter,
    name: &str,
    router: Router,
    addr: unix::net::SocketAddr,
) -> Result<()> {
    info!(addr = ?addr, "starting {name} server");

    let handle = Handle::new();
    arbiter.add_unix_handle(handle.clone()).await;

    axum_server::Server::bind(addr)
        .handle(handle)
        .serve(router.into_make_service())
        .await?;

    Ok(())
}

pub(crate) fn start_unix(
    tasks: &mut Tasks,
    name: &'static str,
    router: Router,
    addr: unix::net::SocketAddr,
) -> Result<()> {
    let arbiter = tasks.arbiter();
    tasks
        .build_task()
        .name(&format!("{}::run_unix({name}, {addr:?})", module_path!()))
        .spawn(run_unix(arbiter, name, router, addr))?;
    Ok(())
}

async fn run_tls(
    arbiter: Arbiter,
    name: &str,
    router: Router,
    addr: net::SocketAddr,
    config: RustlsConfig,
) -> Result<()> {
    info!(addr = addr.to_string(), "starting {name} server");

    let handle = Handle::new();
    arbiter.add_net_handle(handle.clone()).await;

    axum_server::Server::bind(addr)
        .acceptor(TlsAcceptor::new(RustlsAcceptor::new(config).acceptor(
            ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()),
        )))
        .handle(handle)
        .serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
        .await?;

    Ok(())
}

pub(crate) fn start_tls(
    tasks: &mut Tasks,
    name: &'static str,
    router: Router,
    addr: net::SocketAddr,
    config: RustlsConfig,
) -> Result<()> {
    let arbiter = tasks.arbiter();
    tasks
        .build_task()
        .name(&format!("{}::run_tls({name}, {addr})", module_path!()))
        .spawn(run_tls(arbiter, name, router, addr, config))?;
    Ok(())
}
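The `Handle` registered with the arbiter is what enables coordinated shutdown. A standalone sketch of the underlying axum_server pattern these functions rely on, with illustrative names and the shutdown trigger simplified to Ctrl-C:

```rust
use std::{net::SocketAddr, time::Duration};

use axum::Router;
use axum_server::Handle;

// Sketch: keep a clone of the handle; later, request a graceful shutdown
// with a drain deadline instead of dropping connections.
async fn serve_with_handle(router: Router, addr: SocketAddr) -> std::io::Result<()> {
    let handle = Handle::new();
    let shutdown_handle = handle.clone();

    // Elsewhere (the arbiter, in this diff): drain for up to 30 seconds.
    tokio::spawn(async move {
        tokio::signal::ctrl_c().await.ok();
        shutdown_handle.graceful_shutdown(Some(Duration::from_secs(30)));
    });

    axum_server::Server::bind(addr)
        .handle(handle)
        .serve(router.into_make_service())
        .await
}
```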
src/axum/trace.rs (Normal file, 48 lines)
@@ -0,0 +1,48 @@
use std::collections::HashMap;

use axum::{extract::Request, middleware::Next, response::Response};
use tokio::time::Instant;
use tracing::{Instrument as _, field, info, info_span, trace};

use crate::config;

pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
    let config = config::get();
    let http_headers = request
        .headers()
        .iter()
        .filter(|(name, _)| {
            for header in &config.log.http_headers {
                if header.eq_ignore_ascii_case(name.as_str()) {
                    return true;
                }
            }
            false
        })
        .map(|(name, value)| (name.to_string().to_lowercase().replace('-', "_"), value))
        .collect::<HashMap<_, _>>();
    let span = info_span!(
        "request",
        path = %request.uri(),
        method = %request.method(),
        remote = field::Empty,
        scheme = field::Empty,
        host = field::Empty,
        http_headers = ?http_headers,
    );
    next.run(request).instrument(span).await
}

pub(crate) async fn tracing_middleware(request: Request, next: Next) -> Response {
    let event = request.uri().clone();
    trace!("request start");

    let start = Instant::now();
    let response = next.run(request).await;
    let runtime = start.elapsed();
    let status = response.status().as_u16();

    info!(status = status, runtime = runtime.as_millis(), "{event}");

    response
}
src/brands/mod.rs (Normal file, 1 line)
@@ -0,0 +1 @@
pub(crate) mod tls;
src/brands/tls.rs (Normal file, 134 lines)
@@ -0,0 +1,134 @@
use std::{
    collections::{HashMap, hash_map::Entry},
    sync::Arc,
};

use eyre::{Report, Result};
use rustls::{
    RootCertStore,
    crypto::CryptoProvider,
    pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
    server::ClientHello,
    sign::CertifiedKey,
};

use crate::db;

#[derive(Debug)]
struct Brand {
    domain: String,
    default: bool,
    web_certificate: Arc<CertifiedKey>,
}

#[derive(Debug)]
pub(crate) struct CertResolver {
    brands: Vec<Brand>,
}

impl CertResolver {
    pub(crate) fn resolve(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let server_name = client_hello.server_name()?;
        let mut best = None;

        for brand in &self.brands {
            if best.is_none() && brand.default {
                best = Some(Arc::clone(&brand.web_certificate));
            }
            if server_name == brand.domain || server_name.ends_with(&format!(".{}", brand.domain)) {
                best = Some(Arc::clone(&brand.web_certificate));
            }
        }

        best
    }
}

pub(crate) async fn make_cert_managers() -> Result<(CertResolver, RootCertStore)> {
    #[derive(sqlx::FromRow)]
    struct BrandRow {
        brand_uuid: uuid::Uuid,
        domain: String,
        default: bool,
        web_cert_data: Option<String>,
        web_cert_key: Option<String>,
        client_cert_data: Option<String>,
    }

    let rows = sqlx::query_as::<_, BrandRow>(
        "
        SELECT
            b.brand_uuid,
            b.domain,
            b.default,
            wc.certificate_data AS web_cert_data,
            wc.key_data AS web_cert_key,
            cc.certificate_data AS client_cert_data
        FROM authentik_brands_brand b
        LEFT JOIN authentik_crypto_certificatekeypair wc
            ON wc.kp_uuid = b.web_certificate_id
        LEFT JOIN authentik_brands_brand_client_certificates bcc
            ON bcc.brand_id = b.brand_uuid
        LEFT JOIN authentik_crypto_certificatekeypair cc
            ON cc.kp_uuid = bcc.certificatekeypair_id
        ",
    )
    .fetch_all(db::get())
    .await?;

    let (brands, roots) = tokio::task::spawn_blocking(|| {
        let mut brands = HashMap::new();
        let mut roots = RootCertStore::empty();

        for row in rows {
            let BrandRow {
                brand_uuid,
                domain,
                default,
                web_cert_data,
                web_cert_key,
                client_cert_data,
            } = row;

            if let (Some(certificate_data), Some(key_data)) = (web_cert_data, web_cert_key)
                && let Entry::Vacant(e) = brands.entry(brand_uuid)
            {
                let brand = Brand {
                    domain,
                    default,
                    web_certificate: {
                        let cert_chain =
                            CertificateDer::pem_reader_iter(certificate_data.as_bytes())
                                .collect::<Result<Vec<_>, _>>()?;
                        let key_der = PrivateKeyDer::from_pem_reader(key_data.as_bytes())?;
                        let provider =
                            CryptoProvider::get_default().expect("no rustls provider installed");
                        Arc::new(CertifiedKey::new(
                            cert_chain,
                            provider.key_provider.load_private_key(key_der)?,
                        ))
                    },
                };
                e.insert(brand);
            }

            if let Some(certificate_data) = client_cert_data {
                let cert_chain = CertificateDer::pem_reader_iter(certificate_data.as_bytes())
                    .collect::<Result<Vec<_>, _>>()?;
                for cert in cert_chain {
                    roots.add(cert)?;
                }
            }
        }

        Ok::<_, Report>((brands, roots))
    })
    .await??;

    Ok((
        CertResolver {
            brands: brands.into_values().collect(),
        },
        roots,
    ))
}
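For orientation, a hedged sketch of how `CertResolver` could be adapted to rustls' server-side SNI resolution trait so a per-brand certificate is picked during the handshake; the wrapper type is made up, and this assumes rustls' `ResolvesServerCert` trait shape:

```rust
use std::sync::Arc;

use rustls::{
    server::{ClientHello, ResolvesServerCert},
    sign::CertifiedKey,
};

use crate::brands::tls::CertResolver;

// Hypothetical adapter: delegate rustls' per-handshake certificate choice
// to the brand/domain matching implemented above.
#[derive(Debug)]
struct BrandSniResolver(CertResolver);

impl ResolvesServerCert for BrandSniResolver {
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        self.0.resolve(&client_hello)
    }
}
```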
src/config/mod.rs (Normal file, 244 lines)
@@ -0,0 +1,244 @@
use std::{
    env,
    fs::{self, read_to_string},
    path::PathBuf,
    sync::{Arc, OnceLock},
};

use arc_swap::{ArcSwap, Guard};
use eyre::Result;
use notify::{RecommendedWatcher, Watcher as _};
use serde_json::{Map, Value};
use tokio::sync::mpsc;
use tracing::{error, info, warn};

pub(crate) mod schema;

pub(crate) use schema::Config;
use url::Url;

use crate::arbiter::{Arbiter, Tasks};

static DEFAULT_CONFIG: &str = include_str!("../../authentik/lib/default.yml");
static CONFIG_MANAGER: OnceLock<ConfigManager> = OnceLock::new();

fn config_paths() -> Vec<PathBuf> {
    let mut config_paths = vec![
        PathBuf::from("/etc/authentik/config.yml"),
        PathBuf::from(""),
    ];
    if let Ok(workspace) = env::var("WORKSPACE_DIR") {
        let _ = env::set_current_dir(workspace);
    }

    if let Ok(paths) = glob::glob("/etc/authentik/config.d/*.yml") {
        config_paths.extend(paths.filter_map(Result::ok));
    }

    let environment = env::var("AUTHENTIK_ENV").unwrap_or_else(|_| "local".to_owned());

    let mut computed_paths = Vec::new();

    for path in config_paths {
        if let Ok(metadata) = fs::metadata(&path) {
            if !metadata.is_dir() {
                computed_paths.push(path);
            }
        } else {
            let env_paths = vec![
                path.join(format!("{environment}.yml")),
                path.join(format!("{environment}.env.yml")),
            ];
            for env_path in env_paths {
                if let Ok(metadata) = fs::metadata(&env_path)
                    && !metadata.is_dir()
                {
                    computed_paths.push(env_path);
                }
            }
        }
    }

    computed_paths
}

impl Config {
    fn load_raw(config_paths: &[PathBuf]) -> Result<Value> {
        let mut builder = config::Config::builder().add_source(config::File::from_str(
            DEFAULT_CONFIG,
            config::FileFormat::Yaml,
        ));
        for path in config_paths {
            builder = builder
                .add_source(config::File::from(path.as_path()).format(config::FileFormat::Yaml));
        }
        builder = builder.add_source(config::Environment::with_prefix("AUTHENTIK"));
        let config = builder.build()?;
        let raw = config.try_deserialize::<Value>()?;
        Ok(raw)
    }

    fn expand_value(value: &str) -> (String, Option<PathBuf>) {
        let value = value.trim();
        if let Ok(uri) = Url::parse(value) {
            let fallback = uri.query().unwrap_or("").to_owned();
            match uri.scheme() {
                "file" => {
                    let path = uri.path();
                    match read_to_string(path).map(|s| s.trim().to_owned()) {
                        Ok(value) => return (value, Some(PathBuf::from(path))),
                        Err(err) => {
                            error!("failed to read config value from {path}: {err}");
                            return (fallback, Some(PathBuf::from(path)));
                        }
                    }
                }
                "env" => {
                    if let Some(var) = uri.host_str() {
                        if let Ok(value) = env::var(var) {
                            return (value, None);
                        }
                        return (fallback, None);
                    }
                }
                _ => {}
            }
        }

        (value.to_owned(), None)
    }

    fn expand(mut raw: Value) -> (Value, Vec<PathBuf>) {
        let mut file_paths = Vec::new();
        let value = match &mut raw {
            Value::String(s) => {
                let (v, path) = Self::expand_value(s);
                if let Some(path) = path {
                    file_paths.push(path);
                }
                Value::String(v)
            }
            Value::Array(arr) => {
                let mut res = Vec::with_capacity(arr.len());
                for v in arr {
                    let (expanded, paths) = Self::expand(v.clone());
                    file_paths.extend(paths);
                    res.push(expanded);
                }
                Value::Array(res)
            }
            Value::Object(map) => {
                let mut res = Map::with_capacity(map.len());
                for (k, v) in map {
                    let (expanded, paths) = Self::expand(v.clone());
                    file_paths.extend(paths);
                    res.insert(k.clone(), expanded);
                }
                Value::Object(res)
            }
            _ => raw,
        };
        (value, file_paths)
    }

    fn load(config_paths: &[PathBuf]) -> Result<(Self, Vec<PathBuf>)> {
        let raw = Self::load_raw(config_paths)?;
        let (expanded, file_paths) = Self::expand(raw);
        let config: Self = serde_json::from_value(expanded)?;
        Ok((config, file_paths))
    }
}

pub(crate) struct ConfigManager {
    config: ArcSwap<Config>,
    config_paths: Vec<PathBuf>,
    watch_paths: Vec<PathBuf>,
}

impl ConfigManager {
    pub(crate) fn init() -> Result<()> {
        info!("loading config");
        let config_paths = config_paths();
        let mut watch_paths = config_paths.clone();
        let (config, other_paths) = Config::load(&config_paths)?;
        watch_paths.extend(other_paths);
        let manager = Self {
            config: ArcSwap::from_pointee(config),
            config_paths,
            watch_paths,
        };
        CONFIG_MANAGER.get_or_init(|| manager);
        info!("config loaded");
        Ok(())
    }

    pub(crate) fn run(tasks: &mut Tasks) -> Result<()> {
        info!("starting config file watcher");
        let arbiter = tasks.arbiter();
        tasks
            .build_task()
            .name(&format!("{}::watch_config", module_path!()))
            .spawn(watch_config(arbiter))?;
        Ok(())
    }
}

async fn watch_config(arbiter: Arbiter) -> Result<()> {
    let (tx, mut rx) = mpsc::channel(100);
    let mut watcher = RecommendedWatcher::new(
        move |res: notify::Result<notify::Event>| {
            if let Ok(event) = res
                && let notify::EventKind::Modify(_) = &event.kind
            {
                let _ = tx.blocking_send(());
            }
        },
        notify::Config::default(),
    )?;
    let watch_paths = &CONFIG_MANAGER
        .get()
        .expect("failed to get config, has it been initialized?")
        .watch_paths;
    for path in watch_paths {
        watcher.watch(path.as_ref(), notify::RecursiveMode::NonRecursive)?;
    }

    info!("config file watcher started on paths: {:?}", watch_paths);

    loop {
        tokio::select! {
            res = rx.recv() => {
                info!("a configuration file changed, reloading config");
                if res.is_none() {
                    break;
                }
                let manager = CONFIG_MANAGER.get().expect("failed to get config, has it been initialized?");
                match tokio::task::spawn_blocking(|| Config::load(&manager.config_paths)).await? {
                    Ok((new_config, _)) => {
                        info!("configuration reloaded");
                        manager.config.store(Arc::new(new_config));
                        if let Err(err) = arbiter.config_changed_send(()) {
                            warn!("failed to notify of config change, aborting: {err:?}");
                            break;
                        }
                    }
                    Err(err) => {
                        warn!("failed to reload config, continuing with previous config: {err:?}");
                    }
                }
            },
            () = arbiter.shutdown() => break,
        }
    }

    info!("stopping config file watcher");

    Ok(())
}

pub(crate) fn get() -> Guard<Arc<Config>> {
    let manager = CONFIG_MANAGER
        .get()
        .expect("failed to get config, has it been initialized?");
    manager.config.load()
}
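Because the manager stores the config in an `ArcSwap`, `config::get()` hands out a cheap guard over a consistent snapshot even while the watcher hot-reloads it. A small usage sketch (function name is illustrative):

```rust
use tracing::info;

use crate::config;

// Sketch: read a snapshot of the current config; concurrent reloads swap the
// inner Arc without invalidating this guard.
fn log_listeners() {
    let config = config::get();
    for addr in &config.listen.http {
        info!("http listener configured on {addr}");
    }
}
```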
src/config/schema.rs (Normal file, 146 lines)
@@ -0,0 +1,146 @@
use std::{collections::HashMap, net::SocketAddr, num::NonZeroUsize, path::PathBuf};

use ipnet::IpNet;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Config {
    pub(crate) postgresql: PostgreSQLConfig,

    pub(crate) listen: ListenConfig,

    pub(crate) http_timeout: u32,

    pub(crate) debug: bool,
    #[serde(default)]
    pub(crate) secret_key: String,

    pub(crate) log_level: String,
    pub(crate) log: LogConfig,

    pub(crate) error_reporting: ErrorReportingConfig,

    pub(crate) outposts: OutpostsConfig,

    pub(crate) cookie_domain: Option<String>,

    pub(crate) compliance: ComplianceConfig,

    pub(crate) blueprints_dir: PathBuf,
    pub(crate) cert_discovery_dir: PathBuf,

    pub(crate) web: WebConfig,

    pub(crate) worker: WorkerConfig,

    pub(crate) storage: StorageConfig,

    // Outpost specific config
    // These are only relevant for outposts, and cannot be set via YAML
    // They are loaded via this config loader to support file:// schemas
    pub(crate) authentik_host: Option<String>,
    pub(crate) authentik_host_browser: Option<String>,
    pub(crate) authentik_token: Option<String>,
    pub(crate) authentik_insecure: Option<bool>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct PostgreSQLConfig {
    pub(crate) host: String,
    pub(crate) port: u16,
    pub(crate) user: String,
    pub(crate) password: String,
    pub(crate) name: String,

    pub(crate) sslmode: String,
    pub(crate) sslrootcert: Option<String>,
    pub(crate) sslcert: Option<String>,
    pub(crate) sslkey: Option<String>,

    pub(crate) conn_max_age: Option<u64>,
    pub(crate) conn_health_checks: bool,

    pub(crate) default_schema: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ListenConfig {
    pub(crate) http: Vec<SocketAddr>,
    pub(crate) https: Vec<SocketAddr>,
    pub(crate) ldap: Vec<SocketAddr>,
    pub(crate) ldaps: Vec<SocketAddr>,
    pub(crate) radius: Vec<SocketAddr>,
    pub(crate) metrics: Vec<SocketAddr>,
    pub(crate) debug: SocketAddr,
    pub(crate) trusted_proxy_cidrs: Vec<IpNet>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct LogConfig {
    pub(crate) http_headers: Vec<String>,
    pub(crate) rust_log: HashMap<String, String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ErrorReportingConfig {
    pub(crate) enabled: bool,
    pub(crate) sentry_dsn: Option<String>,
    pub(crate) environment: String,
    pub(crate) send_pii: bool,
    pub(crate) sample_rate: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OutpostsConfig {
    pub(crate) disable_embedded_outpost: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ComplianceConfig {
    pub(crate) fips: ComplianceFipsConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ComplianceFipsConfig {
    pub(crate) enabled: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct WebConfig {
    pub(crate) workers: usize,
    pub(crate) threads: usize,
    pub(crate) path: String,
    pub(crate) timeout_http_read_header: String,
    pub(crate) timeout_http_read: String,
    pub(crate) timeout_http_write: String,
    pub(crate) timeout_http_idle: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct WorkerConfig {
    pub(crate) processes: NonZeroUsize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct StorageConfig {
    pub(crate) backend: String,
    pub(crate) file: StorageFileConfig,
    pub(crate) media: Option<StorageOverrideConfig>,
    pub(crate) reports: Option<StorageOverrideConfig>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct StorageFileConfig {
    pub(crate) path: PathBuf,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub(crate) struct StorageOverrideConfig {
    pub(crate) backend: Option<String>,
    pub(crate) file: Option<StorageFileOverrideConfig>,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub(crate) struct StorageFileOverrideConfig {
    pub(crate) path: Option<PathBuf>,
}
src/db.rs (Normal file, 110 lines)
@@ -0,0 +1,110 @@
use std::{str::FromStr as _, sync::OnceLock, time::Duration};

use eyre::Result;
use sqlx::{
    Executor as _, PgPool,
    postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
};
use tracing::{info, log::LevelFilter, trace};

use crate::{
    arbiter::{Arbiter, Tasks},
    authentik_full_version, config,
    mode::Mode,
};

static DB: OnceLock<PgPool> = OnceLock::new();

fn get_connect_opts() -> Result<PgConnectOptions> {
    let config = config::get();
    let mut opts = PgConnectOptions::new()
        .application_name(&format!(
            "authentik-{}@{}",
            Mode::get(),
            authentik_full_version()
        ))
        .host(&config.postgresql.host)
        .port(config.postgresql.port)
        .username(&config.postgresql.user)
        .password(&config.postgresql.password)
        .database(&config.postgresql.name)
        .ssl_mode(PgSslMode::from_str(&config.postgresql.sslmode)?);
    if let Some(sslrootcert) = &config.postgresql.sslrootcert {
        opts = opts.ssl_root_cert_from_pem(sslrootcert.as_bytes().to_vec());
    }
    if let Some(sslcert) = &config.postgresql.sslcert {
        opts = opts.ssl_client_cert_from_pem(sslcert.as_bytes());
    }
    if let Some(sslkey) = &config.postgresql.sslkey {
        opts = opts.ssl_client_key_from_pem(sslkey.as_bytes());
    }
    Ok(opts)
}

async fn update_connect_opts_on_config_change(arbiter: Arbiter) -> Result<()> {
    let mut config_changed_rx = arbiter.config_changed_subscribe();
    info!("starting database watcher for config changes");
    loop {
        tokio::select! {
            res = config_changed_rx.changed() => {
                if let Err(err) = res {
                    trace!("error receiving config changes: {err:?}");
                    break;
                }
                trace!("config change received, refreshing database connection options");
                let db = get();
                db.set_connect_options(get_connect_opts()?);
            },
            () = arbiter.shutdown() => break,
        }
    }

    info!("stopping database watcher for config changes");
    Ok(())
}

pub(crate) async fn init(tasks: &mut Tasks) -> Result<()> {
    info!("initializing database pool");
    let options = get_connect_opts()?;
    let config = config::get();

    let pool_options = PgPoolOptions::new()
        .min_connections(1)
        .max_connections(4)
        .acquire_time_level(LevelFilter::Trace)
        .max_lifetime(config.postgresql.conn_max_age.map(Duration::from_secs))
        .test_before_acquire(config.postgresql.conn_health_checks)
        .after_connect(|conn, _meta| {
            Box::pin(async move {
                let application_name =
                    format!("authentik-{}@{}", Mode::get(), authentik_full_version());
                let default_schema = &config::get().postgresql.default_schema;
                let query = format!(
                    "SET application_name = '{application_name}'; SET search_path = \
                     '{default_schema}';"
                );
                conn.execute(query.as_str()).await?;
                Ok(())
            })
        });

    let pool = pool_options.connect_with(options).await?;
    DB.get_or_init(|| pool);

    let arbiter = tasks.arbiter();
    tasks
        .build_task()
        .name(&format!(
            "{}::update_connect_opts_on_config_change",
            module_path!(),
        ))
        .spawn(update_connect_opts_on_config_change(arbiter))?;

    info!("database pool initialized");
    Ok(())
}

pub(crate) fn get() -> &'static PgPool {
    DB.get()
        .expect("failed to get db, has it been initialized?")
}
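Once `db::init()` has run, the pool is reachable from anywhere via `db::get()`, and `&PgPool` implements `sqlx::Executor`, so queries can run against it directly. A small sketch (the function is illustrative; the table name is taken from the query in src/brands/tls.rs above):

```rust
use eyre::Result;

use crate::db;

// Sketch: count rows through the shared pool.
async fn count_brands() -> Result<i64> {
    let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM authentik_brands_brand")
        .fetch_one(db::get())
        .await?;
    Ok(count)
}
```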

src/main.rs (new file, 220 additions)
@@ -0,0 +1,220 @@
use std::{
    process::exit,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
};

use ::tracing::{error, info, trace};
use argh::FromArgs;
use eyre::{Result, eyre};

use crate::{arbiter::Tasks, config::ConfigManager, mode::Mode};

mod arbiter;
mod axum;
#[cfg(feature = "core")]
mod brands;
mod config;
#[cfg(feature = "core")]
mod db;
mod metrics;
mod mode;
#[cfg(feature = "proxy")]
mod proxy;
#[cfg(feature = "core")]
mod server;
mod tokio;
mod tracing;
#[cfg(feature = "core")]
mod worker;

const VERSION: &str = env!("CARGO_PKG_VERSION");

pub(crate) fn authentik_build_hash(fallback: Option<String>) -> String {
    std::env::var("GIT_BUILD_HASH").unwrap_or_else(|_| fallback.unwrap_or_default())
}

pub(crate) fn authentik_full_version() -> String {
    let build_hash = authentik_build_hash(None);
    if build_hash.is_empty() {
        VERSION.to_owned()
    } else {
        format!("{VERSION}+{build_hash}")
    }
}

pub(crate) fn authentik_user_agent() -> String {
    format!("authentik@{}", authentik_full_version())
}

#[derive(Debug, FromArgs, PartialEq)]
/// The authentication glue you need.
struct Cli {
    #[argh(subcommand)]
    command: Command,
}

#[derive(Debug, FromArgs, PartialEq)]
#[argh(subcommand)]
enum Command {
    #[cfg(feature = "core")]
    AllInOne(AllInOne),
    #[cfg(feature = "core")]
    Server(server::Cli),
    #[cfg(feature = "core")]
    Worker(worker::Cli),
    #[cfg(feature = "proxy")]
    Proxy(proxy::Cli),
    #[cfg(feature = "core")]
    Manage(Manage),
}

#[derive(Debug, FromArgs, PartialEq)]
/// Run the authentik server and worker.
#[argh(subcommand, name = "allinone")]
#[expect(
    clippy::empty_structs_with_brackets,
    reason = "argh doesn't support unit structs"
)]
struct AllInOne {}

#[derive(Debug, FromArgs, PartialEq)]
/// authentik's Django management command.
#[argh(subcommand, name = "manage")]
struct Manage {
    #[argh(positional, greedy)]
    args: Vec<String>,
}

fn main() -> Result<()> {
    let tracing_crude = tracing::install_crude();
    info!(version = env!("CARGO_PKG_VERSION"), "authentik is starting");

    let cli: Cli = argh::from_env();

    match &cli.command {
        #[cfg(feature = "core")]
        Command::AllInOne(_) => Mode::set(Mode::AllInOne)?,
        #[cfg(feature = "core")]
        Command::Server(_) => Mode::set(Mode::Server)?,
        #[cfg(feature = "core")]
        Command::Worker(_) => Mode::set(Mode::Worker)?,
        #[cfg(feature = "proxy")]
        Command::Proxy(_) => Mode::set(Mode::Proxy)?,
        #[cfg(feature = "core")]
        Command::Manage(args) => {
            let mut process = std::process::Command::new("python")
                .args(["-m", "manage"])
                .args(&args.args)
                .spawn()?;
            let status = process.wait()?;
            if let Some(code) = status.code() {
                exit(code);
            }
            return Ok(());
        }
    }

    trace!("installing error formatting");
    color_eyre::install()?;

    trace!("installing rustls crypto provider");
    #[expect(
        clippy::unwrap_in_result,
        reason = "result type does not implement Error"
    )]
    rustls::crypto::aws_lc_rs::default_provider()
        .install_default()
        .expect("Failed to install rustls provider");

    #[cfg(feature = "core")]
    if Mode::is_core() {
        if std::env::var("PROMETHEUS_MULTIPROC_DIR").is_err() {
            let dir = std::env::temp_dir().join("authentik_prometheus_tmp");
            std::fs::create_dir_all(&dir)?;
            #[expect(unsafe_code, reason = "see safety comment below")]
            // SAFETY: there is only one thread at this point, so this is safe.
            unsafe {
                std::env::set_var("PROMETHEUS_MULTIPROC_DIR", dir);
            }
            trace!(
                env = std::env::var("PROMETHEUS_MULTIPROC_DIR").unwrap_or_default(),
                "setting PROMETHEUS_MULTIPROC_DIR"
            );
        } else {
            trace!("PROMETHEUS_MULTIPROC_DIR already set");
        }

        trace!("initializing Python");
        pyo3::Python::initialize();
        trace!("Python initialized");
    }

    ConfigManager::init()?;

    let _sentry = config::get()
        .error_reporting
        .enabled
        .then(tracing::sentry::install);

    tracing::install()?;
    drop(tracing_crude);

    ::tokio::runtime::Builder::new_multi_thread()
        .thread_name_fn(|| {
            static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
            let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
            format!("tokio-{id}")
        })
        .enable_all()
        .build()?
        .block_on(async {
            let mut tasks = Tasks::new()?;

            ConfigManager::run(&mut tasks)?;

            let metrics = metrics::run(&mut tasks)?;

            #[cfg(feature = "core")]
            if Mode::is_core() {
                db::init(&mut tasks).await?;
            }

            match cli.command {
                #[cfg(feature = "core")]
                Command::AllInOne(_) => {
                    let workers = worker::run(worker::Cli::default(), &mut tasks)?;
                    metrics.workers.store(Some(Arc::clone(&workers)));
                    let server = server::run(server::Cli::default(), &mut tasks)?;
                    server.workers.store(Some(workers));
                    metrics.server.store(Some(server));
                }
                #[cfg(feature = "core")]
                Command::Server(args) => {
                    let server = server::run(args, &mut tasks)?;
                    metrics.server.store(Some(server));
                }
                #[cfg(feature = "core")]
                Command::Worker(args) => {
                    let workers = worker::run(args, &mut tasks)?;
                    metrics.workers.store(Some(workers));
                }
                #[cfg(feature = "proxy")]
                Command::Proxy(args) => proxy::run(args, &mut tasks)?,
                #[cfg(feature = "core")]
                Command::Manage(_) => unreachable!(),
            }

            let errors = tasks.run().await;

            if errors.is_empty() {
                info!("authentik exiting");
                Ok(())
            } else {
                error!("authentik encountered errors: {:?}", errors);
                Err(eyre!("Errors encountered: {:?}", errors))
            }
        })
}
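
The `Cli`/`Command` types above are the usual argh subcommand dispatch: a top-level struct with an `#[argh(subcommand)]` field, an enum of subcommands, and one struct per subcommand. A minimal sketch of the same shape (names here are illustrative, not from the file above); `demo server --port 9000` parses into `Command::Server(ServerCmd { port: 9000 })`:

use argh::FromArgs;

#[derive(FromArgs)]
/// Demo top-level command.
struct Cli {
    #[argh(subcommand)]
    command: Command,
}

#[derive(FromArgs)]
#[argh(subcommand)]
enum Command {
    Server(ServerCmd),
}

#[derive(FromArgs)]
/// Run the server.
#[argh(subcommand, name = "server")]
struct ServerCmd {
    /// port to listen on
    #[argh(option, default = "8000")]
    port: u16,
}

fn main() {
    let cli: Cli = argh::from_env();
    match cli.command {
        Command::Server(args) => println!("serving on port {}", args.port),
    }
}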

src/metrics/handlers.rs (new file, 91 additions)
@@ -0,0 +1,91 @@
use std::sync::Arc;

use axum::{body::Body, extract::State, http::StatusCode, response::Response};

#[cfg(feature = "core")]
use crate::mode::Mode;
use crate::{axum::error::Result, metrics::Metrics};

pub(super) async fn metrics_handler(State(state): State<Arc<Metrics>>) -> Result<Response> {
    let mut metrics = Vec::new();
    state.prometheus.render_to_write(&mut metrics)?;

    #[cfg(feature = "core")]
    if Mode::is_core() {
        use axum::http::{Request, header::HOST};

        if [Mode::AllInOne, Mode::Server].contains(&Mode::get()) {
            let req = Request::builder()
                .method("GET")
                .uri("http://localhost:8000/-/metrics/")
                .header(HOST, "localhost")
                .body(Body::from(""));
            if let Ok(req) = req
                && let Some(server) = state.server.load_full()
            {
                let _ = server.client.request(req).await;
            }
        } else if [Mode::Worker].contains(&Mode::get()) {
            let req = Request::builder()
                .method("GET")
                .uri("http://localhost:8000/-/metrics/")
                .header(HOST, "localhost")
                .body(Body::from(""));
            if let Ok(req) = req
                && let Some(workers) = state.workers.load_full()
            {
                let _ = workers.client.request(req).await;
            }
        }
        metrics.extend(tokio::task::spawn_blocking(python::get_python_metrics).await??);
    }

    Ok(Response::builder()
        .status(StatusCode::OK)
        .header("Content-Type", "text/plain; version=1.0.0; charset=utf-8")
        .body(Body::from(metrics))?)
}

#[cfg(feature = "core")]
mod python {
    use eyre::{Report, Result};
    use pyo3::{
        IntoPyObjectExt as _,
        ffi::c_str,
        prelude::*,
        types::{PyBytes, PyDict},
    };

    pub(super) fn get_python_metrics() -> Result<Vec<u8>> {
        let metrics = Python::attach(|py| {
            let locals = PyDict::new(py);
            Python::run(
                py,
                c_str!(
                    r#"
from prometheus_client import (
    CollectorRegistry,
    generate_latest,
    multiprocess,
)

registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
output = generate_latest(registry)
"#
                ),
                None,
                Some(&locals),
            )?;
            let metrics = locals
                .get_item("output")?
                .unwrap_or(PyBytes::new(py, &[]).into_bound_py_any(py)?)
                .cast::<PyBytes>()
                .map_or_else(|_| PyBytes::new(py, &[]), |v| v.to_owned())
                .as_bytes()
                .to_owned();
            Ok::<_, Report>(metrics)
        })?;
        Ok::<_, Report>(metrics)
    }
}
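
The handler above runs the Python metrics collection inside `tokio::task::spawn_blocking` because acquiring the interpreter blocks the calling thread. A sketch of just that pattern, with a plain function standing in for the GIL-bound `python::get_python_metrics`:

// Sketch of the spawn_blocking offload above. The double `??` first unwraps
// the tokio JoinError, then the inner eyre::Result from the blocking closure.
use eyre::Result;

fn blocking_collect() -> Result<Vec<u8>> {
    // stands in for python::get_python_metrics(), which must hold the Python GIL
    Ok(b"# HELP demo_metric ...\n".to_vec())
}

#[tokio::main]
async fn main() -> Result<()> {
    let mut buf = Vec::new();
    buf.extend(tokio::task::spawn_blocking(blocking_collect).await??);
    println!("{}", String::from_utf8_lossy(&buf));
    Ok(())
}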

src/metrics/mod.rs (new file, 84 additions)
@@ -0,0 +1,84 @@
use std::{env::temp_dir, os::unix, sync::Arc, time::Duration};

use arc_swap::ArcSwapOption;
use axum::{Router, routing::any};
use eyre::Result;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};

use crate::{
    arbiter::{Arbiter, Tasks},
    axum::{router::wrap_router, server},
    config,
};
#[cfg(feature = "core")]
use crate::{server::Server, worker::Workers};

mod handlers;

pub(crate) struct Metrics {
    prometheus: PrometheusHandle,
    #[cfg(feature = "core")]
    pub(crate) server: ArcSwapOption<Server>,
    #[cfg(feature = "core")]
    pub(crate) workers: ArcSwapOption<Workers>,
}

impl Metrics {
    fn new() -> Result<Self> {
        let prometheus = PrometheusBuilder::new()
            .with_recommended_naming(true)
            .install_recorder()?;
        Ok(Self {
            prometheus,
            #[cfg(feature = "core")]
            server: ArcSwapOption::empty(),
            #[cfg(feature = "core")]
            workers: ArcSwapOption::empty(),
        })
    }
}

async fn run_upkeep(arbiter: Arbiter, state: Arc<Metrics>) -> Result<()> {
    loop {
        tokio::select! {
            () = tokio::time::sleep(Duration::from_secs(5)) => {
                let state_clone = Arc::clone(&state);
                tokio::task::spawn_blocking(move || state_clone.prometheus.run_upkeep()).await?;
            },
            () = arbiter.shutdown() => return Ok(())
        }
    }
}

fn build_router(state: Arc<Metrics>) -> Router {
    wrap_router(
        Router::new()
            .fallback(any(handlers::metrics_handler))
            .with_state(state),
        true,
    )
}

pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
    let arbiter = tasks.arbiter();
    let metrics = Arc::new(Metrics::new()?);
    let router = build_router(Arc::clone(&metrics));

    tasks
        .build_task()
        .name(&format!("{}::run_upkeep", module_path!(),))
        .spawn(run_upkeep(arbiter, Arc::clone(&metrics)))?;

    for addr in config::get().listen.metrics.iter().copied() {
        server::start_plain(tasks, "metrics", router.clone(), addr)?;
    }

    server::start_unix(
        tasks,
        "metrics",
        router,
        unix::net::SocketAddr::from_pathname(temp_dir().join("authentik-metrics.sock"))?,
    )?;

    Ok(metrics)
}
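
`Metrics::new` above pairs a globally installed recorder with a handle that renders the Prometheus text exposition on demand. A minimal sketch of that pairing under the `metrics` / `metrics-exporter-prometheus` crates (builder options beyond `install_recorder` are omitted here; exact builder methods vary by crate version):

// Sketch: install_recorder registers a global recorder for the metrics macros;
// the returned handle renders the exposition and runs periodic upkeep.
use metrics_exporter_prometheus::PrometheusBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let handle = PrometheusBuilder::new().install_recorder()?;

    metrics::counter!("demo_requests_total").increment(1);
    metrics::histogram!("demo_request_duration_seconds").record(0.042);

    // In the module above, rendering happens in metrics_handler and
    // run_upkeep is driven by a 5-second timer task.
    println!("{}", handle.render());
    handle.run_upkeep();
    Ok(())
}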

src/mode.rs (new file, 78 additions)
@@ -0,0 +1,78 @@
use std::{
    env,
    path::PathBuf,
    sync::atomic::{AtomicU8, Ordering},
};

use eyre::Result;

static MODE: AtomicU8 = AtomicU8::new(0);

fn mode_path() -> PathBuf {
    env::temp_dir().join("authentik-mode")
}

#[derive(PartialEq)]
#[repr(u8)]
pub(crate) enum Mode {
    #[cfg(feature = "core")]
    AllInOne = 0,
    #[cfg(feature = "core")]
    Server = 1,
    #[cfg(feature = "core")]
    Worker = 2,
    #[cfg(feature = "proxy")]
    Proxy = 3,
}

impl std::fmt::Display for Mode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            #[cfg(feature = "core")]
            Self::AllInOne => write!(f, "allinone"),
            #[cfg(feature = "core")]
            Self::Server => write!(f, "server"),
            #[cfg(feature = "core")]
            Self::Worker => write!(f, "worker"),
            #[cfg(feature = "proxy")]
            Self::Proxy => write!(f, "proxy"),
        }
    }
}

impl From<Mode> for u8 {
    #[expect(clippy::as_conversions, reason = "repr of enum is u8")]
    fn from(value: Mode) -> Self {
        value as Self
    }
}

impl Mode {
    pub(crate) fn get() -> Self {
        match MODE.load(Ordering::Relaxed) {
            #[cfg(feature = "core")]
            0 => Self::AllInOne,
            #[cfg(feature = "core")]
            1 => Self::Server,
            #[cfg(feature = "core")]
            2 => Self::Worker,
            #[cfg(feature = "proxy")]
            3 => Self::Proxy,
            _ => unreachable!(),
        }
    }

    pub(crate) fn set(mode: Self) -> Result<()> {
        std::fs::write(mode_path(), mode.to_string())?;
        MODE.store(mode.into(), Ordering::SeqCst);
        Ok(())
    }

    pub(crate) fn is_core() -> bool {
        match Self::get() {
            #[cfg(feature = "core")]
            Self::AllInOne | Self::Server | Self::Worker => true,
            _ => false,
        }
    }
}
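
The module above stores the process mode as a `u8` discriminant in a static atomic and maps it back through an exhaustive match, so an unknown byte can never silently become a mode. A minimal sketch of that round trip:

// Sketch of the enum <-> u8 round trip Mode uses: the discriminant goes into
// an AtomicU8 via `as`, and reads map back through an explicit match.
use std::sync::atomic::{AtomicU8, Ordering};

#[repr(u8)]
#[derive(Debug, PartialEq)]
enum Mode {
    AllInOne = 0,
    Server = 1,
}

static MODE: AtomicU8 = AtomicU8::new(0);

fn get() -> Mode {
    match MODE.load(Ordering::Relaxed) {
        0 => Mode::AllInOne,
        1 => Mode::Server,
        _ => unreachable!("only set() writes this atomic"),
    }
}

fn main() {
    MODE.store(Mode::Server as u8, Ordering::SeqCst);
    assert_eq!(get(), Mode::Server);
}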

src/proxy/mod.rs (new file, 48 additions)
@@ -0,0 +1,48 @@
use argh::FromArgs;
use axum::extract::Request;
use eyre::Result;

use crate::arbiter::{Arbiter, Tasks};

#[derive(Debug, FromArgs, PartialEq)]
/// Run the authentik proxy outpost.
#[argh(subcommand, name = "proxy")]
#[expect(
    clippy::empty_structs_with_brackets,
    reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}

pub(crate) mod tls {
    use std::sync::Arc;

    use rustls::{server::ClientHello, sign::CertifiedKey};

    #[derive(Debug)]
    pub(crate) struct CertResolver;

    impl CertResolver {
        #[expect(clippy::unused_self, reason = "still WIP")]
        pub(crate) fn resolve(&self, _client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
            None
        }
    }
}

pub(crate) fn can_handle(_request: &Request) -> bool {
    false
}

pub(crate) async fn ignore_me(arbiter: Arbiter) -> Result<()> {
    arbiter.shutdown().await;
    Ok(())
}

pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<()> {
    let arbiter = tasks.arbiter();
    tasks
        .build_task()
        .name(&format!("{}::ignore_me", module_path!(),))
        .spawn(ignore_me(arbiter))?;
    Ok(())
}

src/server/core.rs (new file, 448 additions)
@@ -0,0 +1,448 @@
use std::sync::{Arc, LazyLock, atomic::Ordering};

use axum::{
    Extension, Router,
    body::Body,
    extract::{Request, State},
    http::{
        HeaderName, HeaderValue, StatusCode, Uri,
        header::{ACCEPT, CONTENT_TYPE, HOST, LOCATION, RETRY_AFTER},
    },
    middleware::{Next, from_fn},
    response::{IntoResponse, Response},
    routing::any,
};
use http_body_util::BodyExt as _;
use serde_json::json;

use crate::{
    axum::{
        accept::tls::TlsState,
        error::Result,
        extract::{client_ip::ClientIp, host::Host, scheme::Scheme, trusted_proxy::TrustedProxy},
        router::wrap_router,
    },
    config, db,
    server::{
        GUNICORN_READY, Server,
        core::websockets::{handle_websocket_upgrade, is_websocket_upgrade},
    },
};

static STARTUP_RESPONSE_JSON: LazyLock<Response<String>> = LazyLock::new(|| {
    Response::builder()
        .status(StatusCode::SERVICE_UNAVAILABLE)
        .header(RETRY_AFTER, "5")
        .header(CONTENT_TYPE, "application/json")
        .body(
            json!({
                "error": "authentik starting",
            })
            .to_string(),
        )
        .expect("infallible")
});

static STARTUP_RESPONSE_HTML: LazyLock<Response<String>> = LazyLock::new(|| {
    Response::builder()
        .status(StatusCode::SERVICE_UNAVAILABLE)
        .header(CONTENT_TYPE, "text/html")
        .body(include_str!("../../web/dist/standalone/loading/startup.html").to_owned())
        .expect("infallible")
});

static STARTUP_RESPONSE_PLAIN: LazyLock<Response<String>> = LazyLock::new(|| {
    Response::builder()
        .status(StatusCode::SERVICE_UNAVAILABLE)
        .header(CONTENT_TYPE, "text/plain")
        .body("authentik starting".to_owned())
        .expect("infallible")
});

const SERVER: HeaderName = HeaderName::from_static("server");
const X_FORWARDED_CLIENT_CERT: HeaderName = HeaderName::from_static("x-forwarded-client-cert");
const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto");
const X_POWERED_BY: HeaderName = HeaderName::from_static("x-powered-by");

const FORWARD_ALWAYS_REMOVED_HEADERS: [HeaderName; 7] = [
    HeaderName::from_static("forwarded"),
    HeaderName::from_static("host"),
    X_FORWARDED_FOR,
    HeaderName::from_static("x-forwarded-host"),
    X_FORWARDED_PROTO,
    HeaderName::from_static("x-forwarded-scheme"),
    HeaderName::from_static("x-real-ip"),
];
const FORWARD_REMOVED_HEADERS_IF_UNTRUSTED: [HeaderName; 3] = [
    HeaderName::from_static("ssl-client-cert"),             // nginx-ingress
    HeaderName::from_static("x-forwarded-tls-client-cert"), // traefik
    X_FORWARDED_CLIENT_CERT,                                // envoy
];

fn startup_response(accept_header: &str) -> Response {
    let response = if accept_header.contains("application/json") {
        STARTUP_RESPONSE_JSON.clone()
    } else if accept_header.contains("text/html") {
        STARTUP_RESPONSE_HTML.clone()
    } else {
        STARTUP_RESPONSE_PLAIN.clone()
    };

    let (parts, body) = response.into_parts();
    Response::from_parts(parts, body.into())
}

async fn forward_request(
    ClientIp(client_ip): ClientIp,
    Host(host): Host,
    Scheme(scheme): Scheme,
    State(server): State<Arc<Server>>,
    TrustedProxy(trusted_proxy): TrustedProxy,
    tls_state: Option<Extension<TlsState>>,
    mut request: Request,
) -> Result<Response> {
    let accept_header = request
        .headers()
        .get(ACCEPT)
        .map(|v| v.to_str().unwrap_or_default().to_owned())
        .unwrap_or_default();

    if !GUNICORN_READY.load(Ordering::Relaxed) {
        return Ok(startup_response(&accept_header));
    }

    let uri = Uri::builder()
        .scheme("http")
        .authority("localhost:8000")
        .path_and_query(
            request
                .uri()
                .path_and_query()
                .map(|x| x.as_str())
                .unwrap_or_default(),
        )
        .build()?;
    *request.uri_mut() = uri;

    for header_name in FORWARD_ALWAYS_REMOVED_HEADERS {
        request.headers_mut().remove(header_name);
    }
    if !trusted_proxy {
        for header_name in FORWARD_REMOVED_HEADERS_IF_UNTRUSTED {
            request.headers_mut().remove(header_name);
        }
    }

    request.headers_mut().insert(
        X_FORWARDED_FOR,
        HeaderValue::from_str(&client_ip.to_string())?,
    );
    request
        .headers_mut()
        .insert(HOST, HeaderValue::from_str(&host)?);
    request
        .headers_mut()
        .insert(X_FORWARDED_PROTO, HeaderValue::from_str(scheme.as_ref())?);

    if is_websocket_upgrade(request.headers()) {
        return handle_websocket_upgrade(request, server).await;
    }

    if let Some(tls_state) = tls_state
        && let Some(peer_certificates) = &tls_state.peer_certificates
    {
        let xfcc = peer_certificates
            .iter()
            .map(|cert| {
                let pem_encoded = pem::encode(&pem::Pem::new("CERTIFICATE", cert.as_ref()));
                let url_encoded: String =
                    url::form_urlencoded::byte_serialize(pem_encoded.as_bytes()).collect();
                format!("Cert={url_encoded}")
            })
            .collect::<Vec<_>>()
            .join(",");
        // Use the hyphenated header-name constant; the original passed the
        // literal "X_FORWARDED_CLIENT_CERT", which yields an underscore header
        // name rather than x-forwarded-client-cert.
        request
            .headers_mut()
            .insert(X_FORWARDED_CLIENT_CERT, HeaderValue::from_str(&xfcc)?);
    }

    match server.client.request(request).await {
        Ok(res) => {
            let (parts, body) = res.into_parts();
            Ok(Response::from_parts(
                parts,
                Body::from_stream(body.into_data_stream()),
            ))
        }
        Err(_) => Ok(startup_response(&accept_header)),
    }
}

fn build_gunicorn_router(server: Arc<Server>) -> Router {
    wrap_router(
        Router::new().fallback(forward_request).with_state(server),
        config::get().debug, // enable tracing only in debug mode
    )
}

async fn powered_by_middleware(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;
    response.headers_mut().remove(SERVER);
    response
        .headers_mut()
        .insert(X_POWERED_BY, HeaderValue::from_static("authentik"));
    response
}

async fn health_ready(State(server): State<Arc<Server>>) -> impl IntoResponse {
    #[expect(clippy::if_same_then_else, reason = "For easier reading")]
    if !server.is_alive().await {
        StatusCode::SERVICE_UNAVAILABLE
    } else if sqlx::query("SELECT 1").execute(db::get()).await.is_err() {
        StatusCode::SERVICE_UNAVAILABLE
    } else if let Some(workers) = server.workers.load_full()
        && !workers.are_alive().await
    {
        StatusCode::SERVICE_UNAVAILABLE
    } else {
        let req = Request::builder()
            .method("GET")
            .uri("http://localhost:8000/-/health/ready/")
            .header(HOST, "localhost")
            .body(Body::from(""));
        if let Ok(req) = req
            && let Ok(res) = server.client.request(req).await
        {
            res.status()
        } else {
            StatusCode::SERVICE_UNAVAILABLE
        }
    }
}

pub(super) fn build_router(server: Arc<Server>) -> Router {
    let router = wrap_router(
        Router::new()
            .route("/-/metrics/", any((StatusCode::NOT_FOUND, "not found")))
            .route("/-/health/ready/", any(health_ready))
            .with_state(Arc::clone(&server))
            .merge(super::r#static::build_router()),
        true,
    )
    .merge(build_gunicorn_router(server))
    .layer(from_fn(powered_by_middleware));
    let path = &config::get().web.path;
    if config::get().web.path == "/" {
        router
    } else {
        Router::new()
            .route(
                "/",
                any(
                    async || match HeaderValue::try_from(&config::get().web.path) {
                        Ok(location) => (StatusCode::FOUND, [(LOCATION, location)]).into_response(),
                        Err(err) => {
                            (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
                        }
                    },
                ),
            )
            .nest(path, router)
    }
}

mod websockets {
    use std::sync::Arc;

    use axum::{
        body::Body,
        extract::Request,
        http::{
            HeaderMap, HeaderValue, StatusCode,
            header::{
                CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
            },
        },
        response::{IntoResponse as _, Response},
    };
    use futures::{SinkExt as _, StreamExt as _};
    use hyper_util::rt::TokioIo;
    use tokio::{net::UnixStream, sync::mpsc};
    use tokio_tungstenite::{
        WebSocketStream, client_async,
        tungstenite::{Message, handshake::derive_accept_key, protocol::Role},
    };
    use tracing::{debug, trace, warn};

    use crate::{
        axum::error::{AppError, Result},
        server::Server,
    };

    pub(super) fn is_websocket_upgrade(headers: &HeaderMap<HeaderValue>) -> bool {
        let has_upgrade = headers
            .get(UPGRADE)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| v.eq_ignore_ascii_case("websocket"));

        let has_connection = headers
            .get(CONNECTION)
            .and_then(|v| v.to_str().ok())
            .is_some_and(|v| {
                v.split(',')
                    .any(|part| part.trim().eq_ignore_ascii_case("upgrade"))
            });

        let has_websocket_key = headers.contains_key(SEC_WEBSOCKET_KEY);
        let has_websocket_version = headers.contains_key(SEC_WEBSOCKET_VERSION);

        has_upgrade && has_connection && has_websocket_key && has_websocket_version
    }

    pub(super) async fn handle_websocket_upgrade(
        request: Request,
        server: Arc<Server>,
    ) -> Result<Response> {
        let Some(ws_key) = request
            .headers()
            .get(SEC_WEBSOCKET_KEY)
            .and_then(|key| key.to_str().ok())
        else {
            return Ok((StatusCode::BAD_REQUEST, "").into_response());
        };

        let ws_accept = derive_accept_key(ws_key.as_bytes());

        let path_q = request
            .uri()
            .path_and_query()
            .map(|x| x.as_str())
            .unwrap_or_default();
        let uri = format!("ws://localhost:8000{path_q}");

        let mut ws_request =
            tokio_tungstenite::tungstenite::handshake::client::Request::builder().uri(uri);
        for (k, v) in request.headers() {
            ws_request = ws_request.header(k.as_str(), v);
        }
        let ws_request = ws_request.body(())?;

        let response = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(UPGRADE, "websocket")
            .header(CONNECTION, "upgrade")
            .header(SEC_WEBSOCKET_ACCEPT, ws_accept)
            .body(Body::empty())?;

        tokio::spawn(async move {
            if let Err(err) = handle_websocket_connection(request, server, ws_request).await {
                warn!("WebSocket connection error: {}", err.0);
            }
        });

        Ok(response)
    }

    async fn handle_websocket_connection(
        request: Request,
        server: Arc<Server>,
        ws_request: tokio_tungstenite::tungstenite::handshake::client::Request,
    ) -> Result<()> {
        let upgraded = hyper::upgrade::on(request).await?;
        let io = TokioIo::new(upgraded);
        let client_ws = WebSocketStream::from_raw_socket(io, Role::Server, None).await;

        let upstream_ws = {
            let stream = UnixStream::connect(&server.socket_path).await?;
            let (ws_stream, _) = client_async(ws_request, stream).await?;
            ws_stream
        };

        let (mut client_sender, mut client_receiver) = client_ws.split();
        let (mut upstream_sender, mut upstream_receiver) = upstream_ws.split();

        let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
        let close_tx_upstream = close_tx.clone();

        let client_to_upstream = tokio::spawn(async move {
            let mut client_closed = false;
            while let Some(msg) = client_receiver.next().await {
                let msg = msg?;
                match msg {
                    Message::Close(_) => {
                        if !client_closed {
                            upstream_sender.send(Message::Close(None)).await?;
                            let _ = close_tx.send(()).await;
                            client_closed = true;
                            break;
                        }
                    }
                    msg @ (Message::Binary(_)
                    | Message::Text(_)
                    | Message::Ping(_)
                    | Message::Pong(_)) => {
                        if !client_closed {
                            upstream_sender.send(msg).await?;
                        }
                    }
                    Message::Frame(_) => {}
                }
            }
            if !client_closed {
                upstream_sender.send(Message::Close(None)).await?;
                let _ = close_tx.send(()).await;
            }
            Ok::<_, AppError>(())
        });

        let upstream_to_client = tokio::spawn(async move {
            let mut upstream_closed = false;
            while let Some(msg) = upstream_receiver.next().await {
                let msg = msg?;
                match msg {
                    Message::Close(_) => {
                        if !upstream_closed {
                            client_sender.send(Message::Close(None)).await?;
                            let _ = close_tx_upstream.send(()).await;
                            upstream_closed = true;
                            break;
                        }
                    }
                    msg @ (Message::Binary(_)
                    | Message::Text(_)
                    | Message::Ping(_)
                    | Message::Pong(_)) => {
                        if !upstream_closed {
                            client_sender.send(msg).await?;
                        }
                    }
                    Message::Frame(_) => {}
                }
            }
            if !upstream_closed {
                client_sender.send(Message::Close(None)).await?;
                let _ = close_tx_upstream.send(()).await;
            }
            Ok::<_, AppError>(())
        });

        tokio::select! {
            _ = close_rx.recv() => {
                trace!("WebSocket connection closed gracefully");
            },
            res = client_to_upstream => {
                if let Err(err) = res {
                    debug!("Client to upstream task failed: {:?}", err);
                }
            }
            res = upstream_to_client => {
                if let Err(err) = res {
                    debug!("Upstream to client task failed: {:?}", err);
                }
            }
        }

        Ok(())
    }
}
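
`forward_request` above serializes client certificates Envoy-style: each DER certificate becomes a PEM block, the block is percent-encoded, wrapped as `Cert=...`, and the entries are comma-joined. A standalone sketch of just that encoding, using the same `pem` and `url` crate calls as the file above (the dummy DER bytes are illustrative):

// Standalone sketch of the x-forwarded-client-cert value construction.
fn encode_xfcc(peer_certs_der: &[Vec<u8>]) -> String {
    peer_certs_der
        .iter()
        .map(|der| {
            let pem_encoded = pem::encode(&pem::Pem::new("CERTIFICATE", der.clone()));
            let url_encoded: String =
                url::form_urlencoded::byte_serialize(pem_encoded.as_bytes()).collect();
            format!("Cert={url_encoded}")
        })
        .collect::<Vec<_>>()
        .join(",")
}

fn main() {
    // Dummy DER bytes; the real caller passes rustls peer certificates.
    let header_value = encode_xfcc(&[vec![0x30, 0x82]]);
    assert!(header_value.starts_with("Cert=-----BEGIN"));
}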

src/server/mod.rs (new file, 255 additions)
@@ -0,0 +1,255 @@
use std::{
    env::temp_dir,
    os::unix,
    path::PathBuf,
    process::Stdio,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    time::Duration,
};

use arc_swap::ArcSwapOption;
use argh::FromArgs;
use axum::{Router, body::Body, extract::Request, http::status::StatusCode, routing::any};
use eyre::{Result, eyre};
use hyper_unix_socket::UnixSocketConnector;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use nix::{
    sys::signal::{Signal, kill},
    unistd::Pid,
};
use tokio::{
    net::UnixStream,
    process::{Child, Command},
    signal::unix::SignalKind,
    sync::{Mutex, broadcast::error::RecvError},
    time::Instant,
};
use tower::ServiceExt as _;
use tower_http::timeout::TimeoutLayer;
use tracing::{info, trace, warn};

use crate::{
    arbiter::{Arbiter, Tasks},
    axum::server,
    config,
    worker::Workers,
};

pub(super) static GUNICORN_READY: AtomicBool = AtomicBool::new(false);

pub(crate) mod core;
mod r#static;
mod tls;

#[derive(Debug, Default, FromArgs, PartialEq)]
/// Run the authentik server.
#[argh(subcommand, name = "server")]
#[expect(
    clippy::empty_structs_with_brackets,
    reason = "argh doesn't support unit structs"
)]
pub(super) struct Cli {}

pub(crate) struct Server {
    gunicorn: Mutex<Child>,
    socket_path: PathBuf,
    pub(crate) client: Client<UnixSocketConnector<PathBuf>, Body>,
    pub(crate) workers: ArcSwapOption<Workers>,
}

impl Server {
    fn new(socket_path: PathBuf) -> Result<Self> {
        info!("starting gunicorn");
        let gunicorn = Command::new("gunicorn")
            .args([
                "--bind",
                &format!("unix://{}", socket_path.display()),
                "-c",
                "./lifecycle/gunicorn.conf.py",
                "authentik.root.asgi:application",
            ])
            .kill_on_drop(true)
            .stdout(Stdio::inherit())
            .stderr(Stdio::inherit())
            .spawn()?;

        let client = Client::builder(TokioExecutor::new())
            .pool_idle_timeout(Duration::from_secs(60))
            .set_host(false)
            .build(UnixSocketConnector::new(socket_path.clone()));

        Ok(Self {
            gunicorn: Mutex::new(gunicorn),
            socket_path,
            client,
            workers: ArcSwapOption::empty(),
        })
    }

    async fn shutdown(&self, signal: Signal) -> Result<()> {
        trace!(
            signal = signal.as_str(),
            "sending shutdown signal to gunicorn"
        );
        let mut gunicorn = self.gunicorn.lock().await;
        if let Some(id) = gunicorn.id() {
            kill(Pid::from_raw(id.cast_signed()), signal)?;
        }
        gunicorn.wait().await?;
        drop(gunicorn);
        Ok(())
    }

    async fn graceful_shutdown(&self) -> Result<()> {
        info!("gracefully shutting down gunicorn");
        self.shutdown(Signal::SIGTERM).await
    }

    async fn fast_shutdown(&self) -> Result<()> {
        info!("immediately shutting down gunicorn");
        self.shutdown(Signal::SIGINT).await
    }

    async fn is_alive(&self) -> bool {
        let try_wait = self.gunicorn.lock().await.try_wait();
        match try_wait {
            Ok(Some(code)) => {
                warn!("gunicorn has exited with status {code}");
                false
            }
            Ok(None) => true,
            Err(err) => {
                warn!("failed to check the status of gunicorn process, ignoring: {err}");
                true
            }
        }
    }

    async fn is_socket_ready(&self) -> bool {
        let result = UnixStream::connect(&self.socket_path).await;
        trace!("checking if gunicorn is ready: {result:?}");
        result.is_ok()
    }
}

async fn watch_server(arbiter: Arbiter, server: Arc<Server>) -> Result<()> {
    info!("starting server watcher");
    let mut signals_rx = arbiter.signals_subscribe();
    loop {
        tokio::select! {
            signal = signals_rx.recv() => {
                match signal {
                    Ok(signal) => {
                        if signal == SignalKind::user_defined1() {
                            info!("gunicorn notified us ready, marked ready for operation");
                            GUNICORN_READY.store(true, Ordering::Relaxed);
                            arbiter.mark_gunicorn_ready();
                        }
                    },
                    Err(RecvError::Lagged(_)) => {},
                    Err(RecvError::Closed) => {
                        warn!("error receiving signals");
                        return Err(RecvError::Closed.into());
                    }
                }
            },
            () = tokio::time::sleep(Duration::from_secs(1)), if !GUNICORN_READY.load(Ordering::Relaxed) => {
                // On some platforms the SIGUSR1 can be missed.
                // Fall back to probing the gunicorn unix socket and mark ready once it accepts connections.
                if server.is_socket_ready().await {
                    info!("gunicorn socket is accepting connections, marked ready for operation");
                    GUNICORN_READY.store(true, Ordering::Relaxed);
                    arbiter.mark_gunicorn_ready();
                }
            },
            () = tokio::time::sleep(Duration::from_secs(5)) => {
                if !server.is_alive().await {
                    return Err(eyre!("gunicorn has exited unexpectedly"));
                }
            },
            () = arbiter.fast_shutdown() => {
                server.fast_shutdown().await?;
                return Ok(());
            },
            () = arbiter.graceful_shutdown() => {
                server.graceful_shutdown().await?;
                return Ok(());
            },
        }
    }
}

fn build_router(server: Arc<Server>) -> Router {
    let core_router = core::build_router(server);
    let proxy_router: Option<Router> = None;

    let config = config::get();
    let timeout = durstr::parse(&config.web.timeout_http_read_header)
        .expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_read).expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_write).expect("Invalid duration in http timeout")
        + durstr::parse(&config.web.timeout_http_idle).expect("Invalid duration in http timeout");
    let timeout_layer = TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout);

    Router::new()
        .fallback(any(async |request: Request<Body>| {
            metrics::describe_histogram!(
                "authentik_main_request_duration",
                metrics::Unit::Seconds,
                "API request latencies in seconds"
            );
            let now = Instant::now();
            if let Some(proxy_router) = proxy_router
                && crate::proxy::can_handle(&request)
            {
                let res = proxy_router.oneshot(request).await;
                metrics::histogram!("authentik_main_request_duration", "dest" => "embedded_outpost")
                    .record(now.elapsed());
                res
            } else {
                let res = core_router.oneshot(request).await;
                metrics::histogram!("authentik_main_request_duration", "dest" => "core")
                    .record(now.elapsed());
                res
            }
        }))
        .layer(timeout_layer)
}

pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
    let config = config::get();
    let arbiter = tasks.arbiter();

    let server = Arc::new(Server::new(temp_dir().join("authentik-gunicorn.sock"))?);
    tasks
        .build_task()
        .name(&format!("{}::watch_server", module_path!()))
        .spawn(watch_server(arbiter.clone(), Arc::clone(&server)))?;

    let router = build_router(Arc::clone(&server));

    for addr in config.listen.http.iter().copied() {
        server::start_plain(tasks, "server", router.clone(), addr)?;
    }

    let tls_config = tls::make_initial_tls_config()?;
    for addr in config.listen.https.iter().copied() {
        server::start_tls(tasks, "tls", router.clone(), addr, tls_config.clone())?;
    }
    tasks
        .build_task()
        .name(&format!("{}::tls::watch_tls_config", module_path!(),))
        .spawn(tls::watch_tls_config(arbiter, tls_config))?;

    server::start_unix(
        tasks,
        "server",
        router,
        unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
    )?;

    Ok(server)
}
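
`watch_server` above treats a successful connect to the gunicorn unix socket as a readiness fallback in case the SIGUSR1 notification is missed. A standalone sketch of that probe loop (the socket path and timing here are illustrative; the server itself uses temp_dir().join("authentik-gunicorn.sock")):

// Standalone sketch of the readiness fallback: poll the unix socket until a
// connect succeeds, then consider the upstream ready.
use std::{path::Path, time::Duration};

use tokio::net::UnixStream;

async fn wait_until_socket_ready(path: &Path) {
    loop {
        if UnixStream::connect(path).await.is_ok() {
            println!("socket at {} accepts connections, marking ready", path.display());
            return;
        }
        tokio::time::sleep(Duration::from_secs(1)).await;
    }
}

#[tokio::main]
async fn main() {
    // Hypothetical path for the sketch.
    wait_until_socket_ready(Path::new("/tmp/authentik-gunicorn.sock")).await;
}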

src/server/static.rs (new file, 255 additions)
@@ -0,0 +1,255 @@
use std::fmt::Write as _;

use aws_lc_rs::digest;
use axum::{
    Router,
    extract::{Query, Request, State},
    http::{
        HeaderValue, StatusCode,
        header::{CACHE_CONTROL, CONTENT_SECURITY_POLICY, VARY},
    },
    middleware::{self, Next},
    response::{IntoResponse as _, Response},
    routing::any,
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use percent_encoding::percent_decode_str;
use serde::Deserialize;
use time::OffsetDateTime;
use tower_http::{
    compression::{CompressionLayer, predicate::SizeAbove},
    services::fs::ServeDir,
};

use crate::config;

#[derive(Debug, Deserialize)]
struct StorageClaims {
    exp: Option<i64>,
    nbf: Option<i64>,
    path: Option<String>,
}

#[derive(Debug, Deserialize)]
struct StorageTokenQuery {
    token: Option<String>,
}

fn is_storage_token_valid(usage: &str, secret_key: &str, request: &Request) -> bool {
    // Use typed query parsing so `token` is percent-decoded before JWT parsing.
    let token_string = match Query::<StorageTokenQuery>::try_from_uri(request.uri()) {
        Ok(query) => match query.0.token {
            Some(token) if !token.is_empty() => token,
            _ => return false,
        },
        Err(_) => return false,
    };

    let Ok(token_header) = decode_header(&token_string) else {
        return false;
    };

    // Must match what we use in authentik/admin/files/backends/file.py
    if token_header.alg != Algorithm::HS256 {
        return false;
    }

    // Derive a per-usage key so media and reports tokens are not interchangeable.
    let key = format!("{secret_key}:{usage}");
    let key_digest = digest::digest(&digest::SHA256, key.as_bytes());
    let key_hex_digest = key_digest
        .as_ref()
        .iter()
        .fold(String::new(), |mut acc, b| {
            let _ = write!(acc, "{b:02x}");
            acc
        });

    let mut validation = Validation::new(token_header.alg);
    validation.validate_exp = false;
    validation.validate_nbf = false;
    validation.validate_aud = false;
    validation.required_spec_claims.clear();

    let claims = match decode::<StorageClaims>(
        &token_string,
        &DecodingKey::from_secret(key_hex_digest.as_bytes()),
        &validation,
    ) {
        Ok(token) => token.claims,
        Err(_) => return false,
    };

    let now = OffsetDateTime::now_utc().unix_timestamp();
    if claims.exp.unwrap_or(0) < now {
        return false;
    }
    if claims.nbf.unwrap_or(now + 1) > now {
        return false;
    }

    let Some(claim_path) = claims.path else {
        return false;
    };
    // Decode path before comparison so encoded URL segments cannot bypass path binding.
    let Ok(request_path) = percent_decode_str(request.uri().path()).decode_utf8() else {
        return false;
    };
    let request_path = request_path.trim_start_matches('/');
    let expected_path = format!("{usage}/{request_path}");
    if claim_path != expected_path {
        return false;
    }

    true
}

#[derive(Clone)]
struct StorageMiddlewareConfig {
    usage: &'static str,
    set_csp_header: bool,
}

async fn storage_middleware(
    State(config): State<StorageMiddlewareConfig>,
    request: Request,
    next: Next,
) -> Response {
    if !is_storage_token_valid(config.usage, &config::get().secret_key, &request) {
        return (StatusCode::NOT_FOUND, "404 page not found\n").into_response();
    }

    let mut response = next.run(request).await;

    if config.set_csp_header {
        // Since media is user-controlled, better be safe
        response.headers_mut().insert(
            CONTENT_SECURITY_POLICY,
            HeaderValue::from_static("default-src 'none'; style-src 'unsafe-inline'; sandbox"),
        );
    }

    response
}

async fn static_header_middleware(request: Request, next: Next) -> Response {
    let mut response = next.run(request).await;

    response.headers_mut().insert(
        CACHE_CONTROL,
        HeaderValue::from_static("public, no-transform"),
    );
    response.headers_mut().insert(
        "X-authentik-version",
        HeaderValue::from_static(env!("CARGO_PKG_VERSION")),
    );
    response
        .headers_mut()
        .insert(VARY, HeaderValue::from_static("X-authentik-version, Etag"));

    // TODO: etag

    response
}

pub(crate) fn build_router() -> Router {
    let config = config::get();

    let mut router = Router::new().layer(middleware::from_fn(static_header_middleware));

    let dist_fs = ServeDir::new("./web/dist/").append_index_html_on_directories(false);
    let static_fs = ServeDir::new("./web/authentik/").append_index_html_on_directories(false);

    router = router.nest_service("/static/dist/", dist_fs.clone());
    router = router.nest_service("/static/authentik/", static_fs);

    router = router.nest_service("/if/flow/{flow_slug}/assets/", dist_fs.clone());
    router = router.nest_service("/if/admin/assets/", dist_fs.clone());
    router = router.nest_service("/if/user/assets/", dist_fs.clone());
    router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs);

    let default_backend = &config.storage.backend;
    let media_backend = config
        .storage
        .media
        .clone()
        .unwrap_or_default()
        .backend
        .unwrap_or_else(|| default_backend.clone());
    let reports_backend = config
        .storage
        .reports
        .clone()
        .unwrap_or_default()
        .backend
        .unwrap_or_else(|| default_backend.clone());

    let default_path = &config.storage.file.path;

    if media_backend == "file" {
        let media_path = config
            .storage
            .media
            .clone()
            .unwrap_or_default()
            .file
            .unwrap_or_default()
            .path
            .unwrap_or_else(|| default_path.clone())
            .join("media");

        let media_fs = ServeDir::new(media_path).append_index_html_on_directories(false);
        let media_router =
            Router::new()
                .fallback_service(media_fs)
                .layer(middleware::from_fn_with_state(
                    StorageMiddlewareConfig {
                        usage: "media",
                        set_csp_header: true,
                    },
                    storage_middleware,
                ));
        router = router.nest("/files/media/", media_router);
    }

    if reports_backend == "file" {
        let reports_path = config
            .storage
            .reports
            .clone()
            .unwrap_or_default()
            .file
            .unwrap_or_default()
            .path
            .unwrap_or_else(|| default_path.clone())
            .join("reports");

        let reports_fs = ServeDir::new(reports_path).append_index_html_on_directories(false);
        let reports_router =
            Router::new()
                .fallback_service(reports_fs)
                .layer(middleware::from_fn_with_state(
                    StorageMiddlewareConfig {
                        usage: "reports",
                        set_csp_header: false,
                    },
                    storage_middleware,
                ));
        router = router.nest("/files/reports/", reports_router);
    }

    router = router.route(
        "/robots.txt",
        any(async || include_str!("../../web/robots.txt")),
    );
    router = router.route(
        "/.well-known/security.txt",
        any(async || include_str!("../../web/security.txt")),
    );

    router = router.layer(middleware::from_fn(static_header_middleware));

    router = router.layer(CompressionLayer::new().compress_when(SizeAbove::new(32)));

    router
}
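
`is_storage_token_valid` above derives its HMAC key as the lowercase hex SHA-256 digest of "{secret_key}:{usage}", so a "media" token can never validate for "reports". Real tokens are minted on the Python side (per the file.py comment above); the sketch below only mints a token of the shape the validator expects, with illustrative names and secrets:

// Sketch of the token shape the validator accepts, assuming the same key
// derivation: hex(SHA-256("{secret_key}:{usage}")) used as the HS256 secret.
use aws_lc_rs::digest;
use jsonwebtoken::{EncodingKey, Header, encode};
use serde::Serialize;
use std::fmt::Write as _;

#[derive(Serialize)]
struct StorageClaims {
    exp: i64,
    nbf: i64,
    path: String, // must equal "{usage}/{request path}"
}

fn usage_key(secret_key: &str, usage: &str) -> String {
    let digest = digest::digest(&digest::SHA256, format!("{secret_key}:{usage}").as_bytes());
    digest.as_ref().iter().fold(String::new(), |mut acc, b| {
        let _ = write!(acc, "{b:02x}");
        acc
    })
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let now = time::OffsetDateTime::now_utc().unix_timestamp();
    let claims = StorageClaims {
        exp: now + 60,
        nbf: now,
        path: "media/avatars/user.png".to_owned(), // hypothetical request path
    };
    // Header::default() is HS256, matching the algorithm check in the validator.
    let token = encode(
        &Header::default(),
        &claims,
        &EncodingKey::from_secret(usage_key("demo-secret", "media").as_bytes()),
    )?;
    println!("?token={token}");
    Ok(())
}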

src/server/tls.rs (new file, 147 additions)
@@ -0,0 +1,147 @@
use std::{sync::Arc, time::Duration};

use axum_server::tls_rustls::RustlsConfig;
use eyre::Result;
use rcgen::PKCS_ECDSA_P256_SHA256;
use rustls::{
    ServerConfig,
    server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
    sign::CertifiedKey,
};
use tracing::{debug, warn};

use crate::{arbiter::Arbiter, brands, proxy};

pub(super) fn make_initial_tls_config() -> Result<RustlsConfig> {
    let (cert, keypair) = self_signed::generate(&PKCS_ECDSA_P256_SHA256)?;
    Ok(RustlsConfig::from_config(Arc::new(
        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(vec![cert.into()], keypair.into())?,
    )))
}

async fn make_tls_config(fallback: Arc<CertifiedKey>) -> Result<ServerConfig> {
    let (core_resolver, roots) = brands::tls::make_cert_managers().await?;
    let cert_resolver = CertResolver {
        core_resolver,
        proxy_resolver: None,
        fallback,
    };

    let client_cert_verifier = WebPkiClientVerifier::builder(Arc::new(roots))
        .allow_unauthenticated()
        .build()?;

    Ok(ServerConfig::builder()
        .with_client_cert_verifier(client_cert_verifier)
        .with_cert_resolver(Arc::new(cert_resolver)))
}

pub(super) async fn watch_tls_config(arbiter: Arbiter, config: RustlsConfig) -> Result<()> {
    tokio::select! {
        () = arbiter.gunicorn_ready() => {},
        () = arbiter.shutdown() => return Ok(()),
    }

    let fallback = Arc::new(self_signed::generate_certifiedkey(&PKCS_ECDSA_P256_SHA256)?);

    loop {
        match make_tls_config(Arc::clone(&fallback)).await {
            Ok(new_config) => {
                config.reload_from_config(Arc::new(new_config));
                debug!("reloaded tls config");
            }
            Err(err) => {
                warn!("error while reloading tls config {err:?}");
            }
        }

        tokio::select! {
            () = tokio::time::sleep(Duration::from_secs(60)) => {},
            () = arbiter.shutdown() => return Ok(()),
        }
    }
}

#[derive(Debug)]
struct CertResolver {
    core_resolver: brands::tls::CertResolver,
    proxy_resolver: Option<proxy::tls::CertResolver>,
    fallback: Arc<CertifiedKey>,
}

#[expect(
    clippy::missing_trait_methods,
    reason = "the provided methods are sensible enough"
)]
impl ResolvesServerCert for CertResolver {
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        if client_hello.server_name().is_none() {
            Some(Arc::clone(&self.fallback))
        } else if let Some(resolver) = &self.proxy_resolver
            && let Some(cert) = resolver.resolve(&client_hello)
        {
            Some(cert)
        } else if let Some(cert) = self.core_resolver.resolve(&client_hello) {
            Some(cert)
        } else {
            Some(Arc::clone(&self.fallback))
        }
    }
}

mod self_signed {
    use eyre::Result;
    use rcgen::{
        Certificate, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
        KeyPair, KeyUsagePurpose, SanType, SignatureAlgorithm,
    };
    use rustls::{
        crypto::aws_lc_rs::sign::any_supported_type,
        pki_types::{CertificateDer, PrivateKeyDer},
        sign::CertifiedKey,
    };
    use time::{Duration, OffsetDateTime};

    pub(super) fn generate(alg: &'static SignatureAlgorithm) -> Result<(Certificate, KeyPair)> {
        let signing_key = KeyPair::generate_for(alg)?;

        let mut params = CertificateParams::default();
        params.not_before = OffsetDateTime::now_utc();
        params.not_after = OffsetDateTime::now_utc() + Duration::days(365);
        params.distinguished_name = {
            let mut dn = DistinguishedName::new();
            dn.push(DnType::OrganizationName, "authentik");
            dn.push(DnType::CommonName, "authentik default certificate");
            dn
        };
        params.subject_alt_names = vec![SanType::DnsName("*".try_into()?)];
        params.key_usages = vec![
            KeyUsagePurpose::DigitalSignature,
            KeyUsagePurpose::KeyEncipherment,
        ];
        params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];

        let cert = params.self_signed(&signing_key)?;

        Ok((cert, signing_key))
    }

    pub(super) fn generate_certifiedkey(alg: &'static SignatureAlgorithm) -> Result<CertifiedKey> {
        let (cert, keypair) = generate(alg)?;

        let cert_der = cert.der().to_vec();
        let key_der = keypair.serialize_der();

        let private_key =
            PrivateKeyDer::try_from(key_der).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
        let signing_key =
            any_supported_type(&private_key).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;

        Ok(CertifiedKey::new(
            vec![CertificateDer::from(cert_der)],
            signing_key,
        ))
    }
}

src/tokio/mod.rs (new file, 1 addition)
@@ -0,0 +1 @@
pub(crate) mod proxy_protocol;

src/tokio/proxy_protocol/LICENSE (new file, 22 additions)
@@ -0,0 +1,22 @@
MIT License

Copyright (c) 2026 Authentik Security Inc.
Copyright (c) 2023 Tibor Djurica Potpara

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

src/tokio/proxy_protocol/README.md (new file, 1 addition)
@@ -0,0 +1 @@
This is a fork of https://github.com/tibordp/proxy-header/, with the sync code and the encoding code removed, and with support added for making the PROXY protocol optional.
634
src/tokio/proxy_protocol/header.rs
Normal file
634
src/tokio/proxy_protocol/header.rs
Normal file
@@ -0,0 +1,634 @@
|
||||
use std::{borrow::Cow, fmt, net::SocketAddr, str::from_utf8};
|
||||
|
||||
use thiserror::Error;
|
||||
use tracing::instrument;
|
||||
|
||||
/// Protocol type
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
|
||||
pub(crate) enum Protocol {
|
||||
/// Stream protocol (TCP)
|
||||
Stream,
|
||||
/// Datagram protocol (UDP)
|
||||
Datagram,
|
||||
}
|
||||
|
||||
/// Address information from a PROXY protocol header
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
|
||||
pub(crate) struct Address {
|
||||
/// Protocol type
|
||||
pub(crate) protocol: Protocol,
|
||||
/// Source address (of the actual client)
|
||||
pub(crate) source: SocketAddr,
|
||||
/// Destination address (of the proxy)
|
||||
pub(crate) destination: SocketAddr,
|
||||
}
|
||||
|
||||
macro_rules! tlv {
|
||||
($self:expr, $kind:ident) => {{
|
||||
$self.tlvs().find_map(|f| match f {
|
||||
Ok(Tlv::$kind(v)) => Some(v),
|
||||
_ => None,
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! tlv_borrowed {
|
||||
($self:expr, $kind:ident) => {{
|
||||
$self.tlvs().find_map(|f| match f {
|
||||
Ok(Tlv::$kind(v)) => match v {
|
||||
// It is more ergonomic to return the borrowed value directly rather
|
||||
// than it wrapped in a `Cow::Borrowed`. We know that tlvs always borrows
|
||||
// so we can safely unwrap the `Cow::Borrowed` and return the borrowed value.
|
||||
Cow::Owned(_) => unreachable!(),
|
||||
Cow::Borrowed(v) => Some(v),
|
||||
},
|
||||
_ => None,
|
||||
})
|
||||
}};
|
||||
}
|
||||
|
||||
/// Iterator over PROXY protocol TLV fields
|
||||
pub(crate) struct Tlvs<'a> {
|
||||
buf: &'a [u8],
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::missing_trait_methods,
|
||||
reason = "we don't need to implement the other methods here"
|
||||
)]
|
||||
impl<'a> Iterator for Tlvs<'a> {
|
||||
type Item = Result<Tlv<'a>, ProxyProtocolError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.buf.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let kind = self.buf[0];
|
||||
match self
|
||||
.buf
|
||||
.get(1..3)
|
||||
.map(|s| -> usize { u16::from_be_bytes(s.try_into().expect("infallible")).into() })
|
||||
{
|
||||
Some(u) if u + 3 <= self.buf.len() => {
|
||||
let (ret, new) = self.buf.split_at(3 + u);
|
||||
self.buf = new;
|
||||
|
||||
Some(Tlv::decode(kind, &ret[3..]))
|
||||
}
|
||||
_ => {
|
||||
// Malformed TLV, cannot continue
|
||||
self.buf = &[];
|
||||
Some(Err(ProxyProtocolError::Invalid))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Typed TLV field
///
/// Represents the currently known types of TLV fields from the PROXY protocol specification.
/// Non-recognized TLV fields are represented as [`Tlv::Custom`].
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) enum Tlv<'a> {
    /// Application-Layer Protocol Negotiation (ALPN). It is a byte sequence defining the upper
    /// layer protocol in use over the connection. The most common use case will be to pass the
    /// exact copy of the ALPN extension of the Transport Layer Security (TLS) protocol as defined
    /// by RFC7301.
    Alpn(Cow<'a, [u8]>),

    /// Contains the host name value passed by the client, as a UTF-8 encoded string. In case of
    /// TLS being used on the client connection, this is the exact copy of the `server_name`
    /// extension as defined by RFC3546, section 3.1, often referred to as SNI. There are probably
    /// other situations where an authority can be mentioned on a connection without TLS being
    /// involved at all.
    Authority(Cow<'a, str>),

    /// The value of the type `PP2_TYPE_CRC32C` is a 32-bit number storing the `CRC32c` checksum of
    /// the PROXY protocol header.
    ///
    /// When the checksum is supported by the sender, after constructing the header the sender MUST:
    ///
    /// - initialize the checksum field to '0's.
    ///
    /// - calculate the `CRC32c` checksum of the PROXY header as described in RFC4960, Appendix B.
    ///
    /// - put the resultant value into the checksum field, and leave the rest of the bits unchanged.
    ///
    /// If the checksum is provided as part of the PROXY header and the checksum functionality is
    /// supported by the receiver, the receiver MUST:
    ///
    /// - store the received `CRC32c` checksum value aside.
    ///
    /// - replace the 32 bits of the checksum field in the received PROXY header with all '0's and
    ///   calculate a `CRC32c` checksum value of the whole PROXY header.
    ///
    /// - verify that the calculated `CRC32c` checksum is the same as the received `CRC32c`
    ///   checksum. If it is not, the receiver MUST treat the TCP connection providing the header
    ///   as invalid.
    ///
    /// The default procedure for handling an invalid TCP connection is to abort it.
    Crc32c(u32),

    /// The TLV of this type should be ignored when parsed. The value is zero or more bytes. Can be
    /// used for data padding or alignment. Note that it can be used to align only by 3 or more
    /// bytes because a TLV cannot be smaller than that.
    Noop(usize),

    /// The value of the type `PP2_TYPE_UNIQUE_ID` is an opaque byte sequence of up to
    /// 128 bytes generated by the upstream proxy that uniquely identifies the connection.
    ///
    /// The unique ID can be used to easily correlate connections across multiple layers of
    /// proxies, without needing to look up IP addresses and port numbers.
    UniqueId(Cow<'a, [u8]>),

    /// SSL (TLS) information
    ///
    /// See [`SslInfo`] for more information.
    Ssl(SslInfo<'a>),

    /// The type `PP2_TYPE_NETNS` defines the value as the US-ASCII string representation of the
    /// namespace's name.
    Netns(Cow<'a, str>),

    // The following can only appear as a sub-TLV of SslInfo
    /// SSL/TLS version
    SslVersion(Cow<'a, str>),
    /// In all cases, the string representation (in UTF-8) of the Common Name field (OID: 2.5.4.3)
    /// of the client certificate's Distinguished Name is appended using the TLV format and the
    /// type `PP2_SUBTYPE_SSL_CN`. E.g. "example.com".
    SslCn(Cow<'a, str>),
    /// The second level TLV `PP2_SUBTYPE_SSL_CIPHER` provides the US-ASCII string name of the used
    /// cipher, for example "ECDHE-RSA-AES128-GCM-SHA256".
    SslCipher(Cow<'a, str>),
    /// The second level TLV `PP2_SUBTYPE_SSL_SIG_ALG` provides the US-ASCII string name of the
    /// algorithm used to sign the certificate presented by the frontend when the incoming
    /// connection was made over an SSL/TLS transport layer, for example "SHA256".
    SslSigAlg(Cow<'a, str>),
    /// The second level TLV `PP2_SUBTYPE_SSL_KEY_ALG` provides the US-ASCII string name of the
    /// algorithm used to generate the key of the certificate presented by the frontend when the
    /// incoming connection was made over an SSL/TLS transport layer, for example "RSA2048".
    SslKeyAlg(Cow<'a, str>),

    /// Unrecognized or custom TLV field
    Custom(u8, Cow<'a, [u8]>),
}

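// Sketch of the `CRC32c` receiver procedure described on `Tlv::Crc32c`,
// assuming a `crc32c(&[u8]) -> u32` helper (e.g. from the `crc32c` crate);
// nothing in this module verifies the checksum automatically:
//
//     let mut raw = raw_header.to_vec();
//     raw[crc_offset..crc_offset + 4].fill(0); // zero the checksum field
//     if crc32c(&raw) != received_checksum {
//         // treat the connection as invalid (default: abort it)
//     }
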
impl<'a> Tlv<'a> {
    fn decode(kind: u8, data: &'a [u8]) -> Result<Self, ProxyProtocolError> {
        match kind {
            0x01 => Ok(Self::Alpn(data.into())),
            0x02 => Ok(Self::Authority(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x03 => Ok(Self::Crc32c(u32::from_be_bytes(
                data.try_into().map_err(|_| ProxyProtocolError::Invalid)?,
            ))),
            0x04 => Ok(Self::Noop(data.len())),
            0x05 => Ok(Self::UniqueId(data.into())),
            0x20 => Ok(Self::Ssl(SslInfo(
                *data.first().ok_or(ProxyProtocolError::Invalid)?,
                u32::from_be_bytes(
                    data.get(1..5)
                        .ok_or(ProxyProtocolError::Invalid)?
                        .try_into()
                        .map_err(|_| ProxyProtocolError::Invalid)?,
                ),
                data.get(5..).ok_or(ProxyProtocolError::Invalid)?.into(),
            ))),
            0x21 => Ok(Self::SslVersion(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x22 => Ok(Self::SslCn(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x23 => Ok(Self::SslCipher(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x24 => Ok(Self::SslSigAlg(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x25 => Ok(Self::SslKeyAlg(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            0x30 => Ok(Self::Netns(
                from_utf8(data)
                    .map_err(|_| ProxyProtocolError::Invalid)?
                    .into(),
            )),
            t => Ok(Self::Custom(t, data.into())),
        }
    }

    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn into_owned(self) -> Tlv<'static> {
        match self {
            Self::Alpn(v) => Tlv::Alpn(Cow::Owned(v.into_owned())),
            Self::Authority(v) => Tlv::Authority(Cow::Owned(v.into_owned())),
            Self::Crc32c(v) => Tlv::Crc32c(v),
            Self::Noop(v) => Tlv::Noop(v),
            Self::UniqueId(v) => Tlv::UniqueId(Cow::Owned(v.into_owned())),
            Self::Ssl(v) => Tlv::Ssl(v.into_owned()),
            Self::Netns(v) => Tlv::Netns(Cow::Owned(v.into_owned())),
            Self::SslVersion(v) => Tlv::SslVersion(Cow::Owned(v.into_owned())),
            Self::SslCn(v) => Tlv::SslCn(Cow::Owned(v.into_owned())),
            Self::SslCipher(v) => Tlv::SslCipher(Cow::Owned(v.into_owned())),
            Self::SslSigAlg(v) => Tlv::SslSigAlg(Cow::Owned(v.into_owned())),
            Self::SslKeyAlg(v) => Tlv::SslKeyAlg(Cow::Owned(v.into_owned())),
            Self::Custom(a, v) => Tlv::Custom(a, Cow::Owned(v.into_owned())),
        }
    }
}

/// SSL information from a PROXY protocol header
#[derive(PartialEq, Eq, Clone)]
pub(crate) struct SslInfo<'a>(u8, u32, Cow<'a, [u8]>);

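// Field layout, mirroring the spec's `pp2_tlv_ssl` struct: `.0` holds the
// client flag bits (0x01 PP2_CLIENT_SSL, 0x02 PP2_CLIENT_CERT_CONN,
// 0x04 PP2_CLIENT_CERT_SESS), `.1` the verify result, and `.2` the raw
// sub-TLV bytes.
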
impl SslInfo<'_> {
    /// Client connected over SSL/TLS
    ///
    /// The `PP2_CLIENT_SSL` flag indicates that the client connected over SSL/TLS. When this field
    /// is present, the US-ASCII string representation of the TLS version is appended at the end of
    /// the field in the TLV format using the type `PP2_SUBTYPE_SSL_VERSION`.
    pub(crate) fn client_ssl(&self) -> bool {
        self.0 & 0x01 != 0
    }

    /// Client certificate presented in the connection
    ///
    /// `PP2_CLIENT_CERT_CONN` indicates that the client provided a certificate over the current
    /// connection.
    pub(crate) fn client_cert_conn(&self) -> bool {
        self.0 & 0x02 != 0
    }

    /// Client certificate presented in the session
    ///
    /// `PP2_CLIENT_CERT_SESS` indicates that the client provided a certificate at least once over
    /// the TLS session this connection belongs to.
    pub(crate) fn client_cert_sess(&self) -> bool {
        self.0 & 0x04 != 0
    }

    /// Whether the certificate was verified
    ///
    /// The verify field will be zero if the client presented a certificate and it was successfully
    /// verified, and non-zero otherwise.
    pub(crate) fn verify(&self) -> u32 {
        self.1
    }

    /// Iterator over all TLV (type-length-value) fields
    pub(crate) fn tlvs(&self) -> Tlvs<'_> {
        Tlvs { buf: &self.2 }
    }

    // Convenience accessors for common TLVs

    /// SSL version
    ///
    /// See [`Tlv::SslVersion`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn version(&self) -> Option<&str> {
        tlv_borrowed!(self, SslVersion)
    }

    /// SSL CN
    ///
    /// See [`Tlv::SslCn`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn cn(&self) -> Option<&str> {
        tlv_borrowed!(self, SslCn)
    }

    /// SSL cipher
    ///
    /// See [`Tlv::SslCipher`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn cipher(&self) -> Option<&str> {
        tlv_borrowed!(self, SslCipher)
    }

    /// SSL signature algorithm
    ///
    /// See [`Tlv::SslSigAlg`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn sig_alg(&self) -> Option<&str> {
        tlv_borrowed!(self, SslSigAlg)
    }

    /// SSL key algorithm
    ///
    /// See [`Tlv::SslKeyAlg`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn key_alg(&self) -> Option<&str> {
        tlv_borrowed!(self, SslKeyAlg)
    }

    /// Returns an owned version of this struct
    pub(crate) fn into_owned(self) -> SslInfo<'static> {
        SslInfo(self.0, self.1, Cow::Owned(self.2.into_owned()))
    }
}

impl fmt::Debug for SslInfo<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SslInfo")
            .field("verify", &self.verify())
            .field("client_ssl", &self.client_ssl())
            .field("client_cert_conn", &self.client_cert_conn())
            .field("client_cert_sess", &self.client_cert_sess())
            .field("fields", &self.tlvs().collect::<Vec<_>>())
            .finish()
    }
}

/// A PROXY protocol header
#[derive(Default, PartialEq, Eq, Clone, Debug)]
pub(crate) struct Header<'a>(pub(super) Option<Address>, pub(super) Cow<'a, [u8]>);

impl<'a> Header<'a> {
    /// Attempt to parse a PROXY protocol header from the given buffer
    ///
    /// Returns the parsed header and the number of bytes consumed from the buffer. If the header
    /// is incomplete, returns [`ProxyProtocolError::BufferTooShort`] so more data can be read from
    /// the socket.
    ///
    /// If the header is malformed or unsupported, returns [`ProxyProtocolError::Invalid`].
    ///
    /// This function will borrow the buffer for the lifetime of the returned header. If
    /// you need to keep the header around for longer than the buffer, use
    /// [`Header::into_owned`].
    #[instrument(skip_all)]
    pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), ProxyProtocolError> {
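        // v1 headers begin with the ASCII greeting "PROXY", v2 headers with the
        // binary signature "\r\n\r\n\0\r\nQUIT\n", so the first byte is enough
        // to pick a decoder.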
        match buf.first() {
            Some(b'P') => super::v1::decode(buf),
            Some(b'\r') => super::v2::decode(buf),
            None => Err(ProxyProtocolError::BufferTooShort),
            _ => Err(ProxyProtocolError::Invalid),
        }
    }

    /// Proxied address information
    ///
    /// If `None`, this indicates so-called "local" mode, where the connection is not proxied.
    /// This is usually the case when the connection is initiated by the proxy itself, e.g. for
    /// health checks.
    pub(crate) fn proxied_address(&self) -> Option<&Address> {
        self.0.as_ref()
    }

    /// Iterator that yields all extension TLV (type-length-value) fields present in the header
    ///
    /// See [`Tlv`] for more information on the different types of TLV fields.
    pub(crate) fn tlvs(&self) -> Tlvs<'_> {
        Tlvs { buf: &self.1 }
    }

    // Convenience accessors for common fields

    /// Raw ALPN extension data
    ///
    /// See [`Tlv::Alpn`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn alpn(&self) -> Option<&[u8]> {
        tlv_borrowed!(self, Alpn)
    }

    /// Authority - typically the hostname of the client (SNI)
    ///
    /// See [`Tlv::Authority`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn authority(&self) -> Option<&str> {
        tlv_borrowed!(self, Authority)
    }

    /// `CRC32c` checksum of the address information
    ///
    /// See [`Tlv::Crc32c`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn crc32c(&self) -> Option<u32> {
        tlv!(self, Crc32c)
    }

    /// Unique ID of the connection
    ///
    /// See [`Tlv::UniqueId`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn unique_id(&self) -> Option<&[u8]> {
        tlv_borrowed!(self, UniqueId)
    }

    /// SSL information
    ///
    /// See [`Tlv::Ssl`] for more information.
    pub(crate) fn ssl(&self) -> Option<SslInfo<'_>> {
        tlv!(self, Ssl)
    }

    /// Network namespace
    ///
    /// See [`Tlv::Netns`] for more information.
    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn netns(&self) -> Option<&str> {
        tlv_borrowed!(self, Netns)
    }

    /// Returns an owned version of this struct
    pub(crate) fn into_owned(self) -> Header<'static> {
        Header(self.0, Cow::Owned(self.1.into_owned()))
    }
}

#[derive(Debug, PartialEq, Eq, Error)]
pub(crate) enum ProxyProtocolError {
    #[error("The buffer is too short to contain a complete PROXY protocol header")]
    BufferTooShort,
    #[error("The PROXY protocol header is malformed")]
    Invalid,
}

#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use super::*;

    const V1_UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";

    const V1_TCPV4: &[u8] = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
    const V1_TCPV6: &[u8] = b"PROXY TCP6 2001:db8::1 ::1 12345 443\r\n";

    const V2_LOCAL: &[u8] =
        b"\r\n\r\n\0\r\nQUIT\n \0\0\x0f\x03\0\x04\x88\x9d\xa1\xdf \0\x05\0\0\0\0\0";

    const V2_TCPV4: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 12, 127, 0, 0, 1, 192, 168, 0, 1,
        48, 57, 1, 187,
    ];
    const V2_TCPV6: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 33, 0, 36, 32, 1, 13, 184, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 48, 57, 1, 187,
    ];
    const V2_TCPV4_TLV: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 104, 127, 0, 0, 1, 192, 168, 0,
        1, 48, 57, 1, 187, 3, 0, 4, 211, 153, 216, 216, 5, 0, 4, 49, 50, 51, 52, 32, 0, 75, 7, 0,
        0, 0, 0, 33, 0, 7, 84, 76, 83, 118, 49, 46, 51, 34, 0, 9, 108, 111, 99, 97, 108, 104, 111,
        115, 116, 37, 0, 7, 82, 83, 65, 52, 48, 57, 54, 36, 0, 10, 82, 83, 65, 45, 83, 72, 65, 50,
        53, 54, 35, 0, 22, 84, 76, 83, 95, 65, 69, 83, 95, 50, 53, 54, 95, 71, 67, 77, 95, 83, 72,
        65, 51, 56, 52,
    ];

    #[test]
    fn parse_proxy_header_too_short() {
        for case in [
            V1_TCPV4,
            V1_TCPV6,
            V1_UNKNOWN,
            V2_TCPV4,
            V2_TCPV6,
            V2_TCPV4_TLV,
            V2_LOCAL,
        ]
        .iter()
        {
            for i in 0..case.len() {
                assert!(matches!(
                    Header::parse(&case[..i]),
                    Err(ProxyProtocolError::BufferTooShort)
                ));
            }

            assert!(Header::parse(case).is_ok());
        }
    }

    #[test]
    fn test_parse_proxy_header_v1_unterminated() {
        let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR";
        assert!(matches!(
            Header::parse(line),
            Err(ProxyProtocolError::Invalid)
        ));
    }

    #[test]
    fn test_parse_proxy_header_v1() {
        let (res, consumed) = Header::parse(V1_TCPV4).expect("failed to parse");
        assert_eq!(consumed, V1_TCPV4.len());
        assert_eq!(
            res.0,
            Some(Address {
                protocol: Protocol::Stream,
                source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
                destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
            })
        );
        assert_eq!(res.1, vec![0; 0]);

        let (res, consumed) = Header::parse(V1_TCPV6).expect("failed to parse");
        assert_eq!(consumed, V1_TCPV6.len());
        assert_eq!(
            res.0,
            Some(Address {
                protocol: Protocol::Stream,
                source: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                    12345
                ),
                destination: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    443
                ),
            })
        );
        assert_eq!(res.1, vec![0; 0]);
    }

    #[test]
    fn test_parse_proxy_header_v2() {
        let (res, consumed) = Header::parse(V2_LOCAL).expect("failed to parse");
        assert_eq!(consumed, V2_LOCAL.len());
        assert_eq!(res.0, None);

        let (res, consumed) = Header::parse(V2_TCPV4).expect("failed to parse");
        assert_eq!(consumed, V2_TCPV4.len());
        assert_eq!(
            res.0,
            Some(Address {
                protocol: Protocol::Stream,
                source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
                destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
            })
        );
        assert_eq!(res.1, vec![0; 0]);

        let (res, consumed) = Header::parse(V2_TCPV6).expect("failed to parse");
        assert_eq!(consumed, V2_TCPV6.len());
        assert_eq!(
            res.0,
            Some(Address {
                protocol: Protocol::Stream,
                source: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
                    12345
                ),
                destination: SocketAddr::new(
                    IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                    443
                ),
            })
        );
        assert_eq!(res.1, vec![0; 0]);
    }

    #[test]
    fn test_parse_proxy_header_v2_with_tlvs() {
        let (res, _) = Header::parse(V2_TCPV4_TLV).expect("failed to parse");

        let mut fields = res.tlvs();
        assert_eq!(fields.next(), Some(Ok(Tlv::Crc32c(0xd399_d8d8))));
        assert_eq!(fields.next(), Some(Ok(Tlv::UniqueId(b"1234"[..].into()))));

        let ssl = fields
            .next()
            .expect("next tlv missing")
            .expect("tlv parsing failed");
        let ssl = match ssl {
            Tlv::Ssl(ssl) => ssl,
            _ => panic!("expected SSL TLV"),
        };

        assert_eq!(ssl.verify(), 0);
        assert!(ssl.client_ssl());
        assert!(ssl.client_cert_conn());
        assert!(ssl.client_cert_sess());

        let mut f = ssl.tlvs();

        assert_eq!(f.next(), Some(Ok(Tlv::SslVersion("TLSv1.3".into()))));
        assert_eq!(f.next(), Some(Ok(Tlv::SslCn("localhost".into()))));
        assert_eq!(f.next(), Some(Ok(Tlv::SslKeyAlg("RSA4096".into()))));
        assert_eq!(f.next(), Some(Ok(Tlv::SslSigAlg("RSA-SHA256".into()))));
        assert_eq!(
            f.next(),
            Some(Ok(Tlv::SslCipher("TLS_AES_256_GCM_SHA384".into())))
        );
        assert!(f.next().is_none());

        assert!(fields.next().is_none());
    }
}

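As a reading aid for the hand-written v2 test vectors, here is a standalone sketch (illustrative only, not part of the diff) that decomposes `V2_TCPV4` field by field:

```
// Sketch: decompose the V2_TCPV4 test vector by hand.
fn main() {
    const V2_TCPV4: &[u8] = &[
        13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, // signature "\r\n\r\n\0\r\nQUIT\n"
        33, // 0x21: version 2, PROXY command (0x20 would be LOCAL)
        17, // 0x11: AF_INET, SOCK_STREAM
        0, 12, // remaining length: 12 bytes (4 + 4 for addresses, 2 + 2 for ports)
        127, 0, 0, 1, // source address
        192, 168, 0, 1, // destination address
        48, 57, // source port
        1, 187, // destination port
    ];
    let rest = u16::from_be_bytes([V2_TCPV4[14], V2_TCPV4[15]]);
    let src_port = u16::from_be_bytes([V2_TCPV4[24], V2_TCPV4[25]]);
    let dst_port = u16::from_be_bytes([V2_TCPV4[26], V2_TCPV4[27]]);
    assert_eq!((rest, src_port, dst_port), (12, 12345, 443));
}
```
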
src/tokio/proxy_protocol/mod.rs (new file, 243 lines)
@@ -0,0 +1,243 @@
use std::{
    cmp::min,
    io,
    io::IoSlice,
    ops::Deref,
    pin::Pin,
    task::{Context, Poll},
};

use eyre::{Result, eyre};
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf};
use tracing::instrument;

use crate::tokio::proxy_protocol::header::{Header, ProxyProtocolError};

pub(crate) mod header;
mod utils;
mod v1;
mod v2;

// Length of the read buffer (536 bytes matches the classic default TCP maximum
// segment size from RFC 1122)
const READ_BUFFER_LEN: usize = 536;

pin_project! {
    pub struct ProxyProtocolStream<S> {
        #[pin]
        stream: S,
        remaining: Vec<u8>,
        header: Option<Header<'static>>,
    }
}

impl<S> ProxyProtocolStream<S> {
    pub(crate) fn header(&self) -> Option<&Header<'static>> {
        self.header.as_ref()
    }

    fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
        self.project().stream
    }

    #[expect(unused, reason = "Left here if we extract this to a public library")]
    pub(crate) fn try_into_stream(self) -> Result<S> {
        if self.remaining.is_empty() {
            Ok(self.stream)
        } else {
            Err(eyre!(
                "Cannot return inner stream because buffer is not empty"
            ))
        }
    }
}

impl<S> AsRef<S> for ProxyProtocolStream<S> {
    fn as_ref(&self) -> &S {
        &self.stream
    }
}

impl<S> Deref for ProxyProtocolStream<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        &self.stream
    }
}

impl<S> ProxyProtocolStream<S>
where S: AsyncRead + Unpin
{
    #[instrument(skip_all)]
    pub(crate) async fn new(mut stream: S) -> Result<Self, io::Error> {
        let mut remaining = Vec::with_capacity(READ_BUFFER_LEN);

        loop {
            let bytes_read = stream.read_buf(&mut remaining).await?;
            if bytes_read == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "end of stream",
                ));
            }

            match Header::parse(&remaining) {
                Ok((header, consumed)) => {
                    let header = header.into_owned();
                    remaining.drain(..consumed);

                    return Ok(Self {
                        stream,
                        remaining,
                        header: Some(header),
                    });
                }
                Err(ProxyProtocolError::BufferTooShort) => {}
                // Something went wrong parsing the PROXY protocol. We assume that we weren't meant
                // to parse it, and that this is just a regular stream without the PROXY protocol.
                Err(_) => {
                    return Ok(Self {
                        stream,
                        remaining,
                        header: None,
                    });
                }
            }
        }
    }
}

impl<S> AsyncRead for ProxyProtocolStream<S>
where S: AsyncRead
{
    #[instrument(skip_all)]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.project();

        if !this.remaining.is_empty() {
            let to_copy = min(this.remaining.len(), buf.remaining());

            buf.put_slice(&this.remaining[..to_copy]);
            this.remaining.drain(..to_copy);

            return Poll::Ready(Ok(()));
        }

        this.stream.poll_read(cx, buf)
    }
}

impl<S> AsyncBufRead for ProxyProtocolStream<S>
where S: AsyncBufRead
{
    #[instrument(skip_all)]
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let this = self.project();

        if !this.remaining.is_empty() {
            return Poll::Ready(Ok(&this.remaining[..]));
        }

        this.stream.poll_fill_buf(cx)
    }

    #[instrument(skip_all)]
    fn consume(self: Pin<&mut Self>, amt: usize) {
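        // Consume buffered bytes left over from header parsing first; any
        // amount beyond the local buffer is forwarded to the inner stream.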
        let this = self.project();

        if this.remaining.is_empty() {
            this.stream.consume(amt);
        } else {
            let len = this.remaining.len();
            if amt <= len {
                this.remaining.drain(..amt);
            } else {
                this.remaining.drain(..len);
                this.stream.consume(amt - len);
            }
        }
    }
}

impl<S> AsyncWrite for ProxyProtocolStream<S>
where S: AsyncWrite
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.stream.is_write_vectored()
    }
}

#[cfg(test)]
mod tests {
    use std::{
        io::Cursor,
        net::{IpAddr, Ipv4Addr, SocketAddr},
    };

    use super::{
        header::{Address, Protocol},
        *,
    };

    #[tokio::test]
    async fn test_parse() {
        let mut buf = [0; 1024];
        let header = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
        buf[..header.len()].copy_from_slice(header);
        buf[header.len()..].fill(255);

        let mut stream = Cursor::new(&buf);

        let mut proxied = ProxyProtocolStream::new(&mut stream)
            .await
            .expect("failed to create stream");
        assert_eq!(
            proxied.header(),
            Some(Header(
                Some(Address {
                    protocol: Protocol::Stream,
                    source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
                    destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
                }),
                vec![0; 0].into(),
            ))
            .as_ref()
        );

        let mut buf = Vec::new();
        AsyncReadExt::read_to_end(&mut proxied, &mut buf)
            .await
            .expect("failed to read from stream");
        assert_eq!(buf.len(), 1024 - header.len());
        assert!(buf.into_iter().all(|b| b == 255));
    }
}

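A minimal usage sketch of the stream wrapper (illustrative only; it assumes the items above are in scope and that connections are accepted with tokio):

```
// Sketch: wrap an accepted connection and log the proxied client address.
async fn handle(conn: tokio::net::TcpStream) -> std::io::Result<()> {
    let stream = ProxyProtocolStream::new(conn).await?;
    if let Some(addr) = stream.header().and_then(|h| h.proxied_address()) {
        println!("real client address: {}", addr.source);
    }
    // `stream` now yields the application data that followed the header.
    Ok(())
}
```
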
src/tokio/proxy_protocol/utils.rs (new file, 39 lines)
@@ -0,0 +1,39 @@
use std::{
    net::{IpAddr, Ipv4Addr, Ipv6Addr},
    str::FromStr,
};

pub(super) fn read_until(buf: &[u8], delim: u8) -> Option<&[u8]> {
    for i in 0..buf.len() {
        if buf[i] == delim {
            return Some(&buf[..i]);
        }
    }
    None
}

pub(super) trait AddressFamily: FromStr + Into<IpAddr> {
    const BYTES: usize;

    fn from_slice(slice: &[u8]) -> Self;
}

impl AddressFamily for Ipv4Addr {
    #[expect(clippy::as_conversions, reason = "will always be in bounds")]
    const BYTES: usize = (Self::BITS / 8) as usize;

    fn from_slice(slice: &[u8]) -> Self {
        let arr: [u8; Self::BYTES] = slice.try_into().expect("slice must be 4 bytes");
        arr.into()
    }
}

impl AddressFamily for Ipv6Addr {
    #[expect(clippy::as_conversions, reason = "will always be in bounds")]
    const BYTES: usize = (Self::BITS / 8) as usize;

    fn from_slice(slice: &[u8]) -> Self {
        let arr: [u8; Self::BYTES] = slice.try_into().expect("slice must be 16 bytes");
        arr.into()
    }
}

src/tokio/proxy_protocol/v1.rs (new file, 110 lines)
@@ -0,0 +1,110 @@
use std::{
    borrow::Cow,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
    str::{FromStr as _, from_utf8},
};

use super::{
    header::{Address, Header, Protocol, ProxyProtocolError},
    utils::{AddressFamily, read_until},
};

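// The spec caps a v1 header line at 107 bytes, including the trailing CRLF.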
const MAX_LENGTH: usize = 107;
const GREETING: &[u8] = b"PROXY";
const UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
// All other valid PROXY headers are longer than this
const MIN_LENGTH: usize = UNKNOWN.len();

fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, ProxyProtocolError> {
    let Some(address) = read_until(&buf[*pos..], b' ') else {
        return Err(ProxyProtocolError::BufferTooShort);
    };

    let addr = from_utf8(address)
        .map_err(|_| ProxyProtocolError::Invalid)
        .and_then(|s| A::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
    *pos += address.len() + 1;

    Ok(addr)
}

fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, ProxyProtocolError> {
    let Some(port) = read_until(&buf[*pos..], delim) else {
        return Err(ProxyProtocolError::BufferTooShort);
    };

    let p = from_utf8(port)
        .map_err(|_| ProxyProtocolError::Invalid)
        .and_then(|s| u16::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
    *pos += port.len() + 1;

    Ok(p)
}

fn parse_addrs<A: AddressFamily>(
    buf: &[u8],
    pos: &mut usize,
) -> Result<Address, ProxyProtocolError> {
    let src_addr: A = parse_addr(buf, pos)?;
    let dst_addr: A = parse_addr(buf, pos)?;
    let src_port = parse_port(buf, pos, b' ')?;
    let dst_port = parse_port(buf, pos, b'\r')?;

    Ok(Address {
        protocol: Protocol::Stream, // v1 only supports TCP
        source: SocketAddr::new(src_addr.into(), src_port),
        destination: SocketAddr::new(dst_addr.into(), dst_port),
    })
}

fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
    let mut pos = 0;

    if buf.len() < MIN_LENGTH {
        return Err(ProxyProtocolError::BufferTooShort);
    }
    if !buf.starts_with(GREETING) {
        return Err(ProxyProtocolError::Invalid);
    }
    pos += GREETING.len() + 1;

    let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
        let Some(rest) = read_until(&buf[pos..], b'\r') else {
            return Err(ProxyProtocolError::BufferTooShort);
        };
        pos += rest.len() + 1;

        None
    } else {
        let proto = &buf[pos..pos + 5];
        pos += 5;

        match proto {
            b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
            b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
            _ => return Err(ProxyProtocolError::Invalid),
        }
    };

    match buf.get(pos) {
        Some(b'\n') => pos += 1,
        None => return Err(ProxyProtocolError::BufferTooShort),
        _ => return Err(ProxyProtocolError::Invalid),
    }

    Ok((Header(addrs, Cow::default()), pos))
}

/// Decode a version 1 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
    // Guard against a malicious client sending a very long header, since it is a
    // delimited protocol.
    match decode_inner(buf) {
        Err(ProxyProtocolError::BufferTooShort) if buf.len() >= MAX_LENGTH => {
            Err(ProxyProtocolError::Invalid)
        }
        other => other,
    }
}

src/tokio/proxy_protocol/v2.rs (new file, 122 lines)
@@ -0,0 +1,122 @@
use std::{
    borrow::Cow,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};

use super::{
    header::{Address, Header, Protocol, ProxyProtocolError},
    utils::AddressFamily,
};

const GREETING: &[u8] = b"\r\n\r\n\x00\r\nQUIT\n";
const MIN_LENGTH: usize = GREETING.len() + 4;
const AF_UNIX_ADDRS_LEN: usize = 216;

fn parse_addrs<T: AddressFamily>(
    buf: &[u8],
    pos: &mut usize,
    rest: &mut usize,
    protocol: Protocol,
) -> Result<Address, ProxyProtocolError> {
    if buf.len() < *pos + T::BYTES * 2 + 4 {
        return Err(ProxyProtocolError::BufferTooShort);
    }
    if *rest < T::BYTES * 2 + 4 {
        return Err(ProxyProtocolError::Invalid);
    }

    let addr = Address {
        protocol,
        source: SocketAddr::new(
            T::from_slice(&buf[*pos..*pos + T::BYTES]).into(),
            u16::from_be_bytes([buf[*pos + T::BYTES * 2], buf[*pos + T::BYTES * 2 + 1]]),
        ),
        destination: SocketAddr::new(
            T::from_slice(&buf[*pos + T::BYTES..*pos + T::BYTES * 2]).into(),
            u16::from_be_bytes([buf[*pos + T::BYTES * 2 + 2], buf[*pos + T::BYTES * 2 + 3]]),
        ),
    };

    *rest -= T::BYTES * 2 + 4;
    *pos += T::BYTES * 2 + 4;

    Ok(addr)
}

/// Decode a version 2 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
    let mut pos = 0;

    if buf.len() < MIN_LENGTH {
        return Err(ProxyProtocolError::BufferTooShort);
    }
    if !buf.starts_with(GREETING) {
        return Err(ProxyProtocolError::Invalid);
    }
    pos += GREETING.len();

    let is_local = match buf[pos] {
        0x20 => true,
        0x21 => false,
        _ => return Err(ProxyProtocolError::Invalid),
    };
    let protocol = buf[pos + 1];
    let mut rest: usize = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]).into();
    pos += 4;

    if buf.len() < pos + rest {
        return Err(ProxyProtocolError::BufferTooShort);
    }

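    // Family/protocol byte: the high nibble is the address family (1 = AF_INET,
    // 2 = AF_INET6, 3 = AF_UNIX) and the low nibble the transport
    // (1 = STREAM, 2 = DGRAM), hence the pairs matched below.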
    let addr_info = match protocol {
        0x00 => None,
        0x11 => Some(parse_addrs::<Ipv4Addr>(
            buf,
            &mut pos,
            &mut rest,
            Protocol::Stream,
        )?),
        0x12 => Some(parse_addrs::<Ipv4Addr>(
            buf,
            &mut pos,
            &mut rest,
            Protocol::Datagram,
        )?),
        0x21 => Some(parse_addrs::<Ipv6Addr>(
            buf,
            &mut pos,
            &mut rest,
            Protocol::Stream,
        )?),
        0x22 => Some(parse_addrs::<Ipv6Addr>(
            buf,
            &mut pos,
            &mut rest,
            Protocol::Datagram,
        )?),
        0x31 | 0x32 => {
            // AF_UNIX - we don't parse it, but don't reject it either in case we need the TLVs
            if rest < AF_UNIX_ADDRS_LEN {
                return Err(ProxyProtocolError::Invalid);
            }
            rest -= AF_UNIX_ADDRS_LEN;
            pos += AF_UNIX_ADDRS_LEN;
            None
        }
        _ => return Err(ProxyProtocolError::Invalid),
    };

    let tlv_data = Cow::Borrowed(&buf[pos..pos + rest]);

    pos += rest;

    let header = if is_local {
        Header(None, tlv_data)
    } else {
        Header(addr_info, tlv_data)
    };

    Ok((header, pos))
}

src/tracing.rs (new file, 123 lines)
@@ -0,0 +1,123 @@
use eyre::Result;
use tracing_error::ErrorLayer;
use tracing_subscriber::{filter::EnvFilter, fmt, prelude::*};

use crate::config;

pub(super) fn install() -> Result<()> {
    let config = config::get();

    let mut filter_layer = EnvFilter::builder()
        .with_default_directive(config.log_level.parse()?)
        .parse(&config.log_level)?;
    for (k, v) in &config.log.rust_log {
        filter_layer = filter_layer.add_directive(format!("{k}={v}").parse()?);
    }

    if config.debug {
        let console_layer = console_subscriber::ConsoleLayer::builder()
            .server_addr(config.listen.debug)
            .spawn();
        tracing_subscriber::registry()
            .with(ErrorLayer::default())
            .with(console_layer)
            .with(
                fmt::layer()
                    .compact()
                    .event_format(
                        fmt::format()
                            .with_thread_ids(true)
                            .with_thread_names(true)
                            .with_source_location(true)
                            .compact(),
                    )
                    .with_writer(std::io::stderr)
                    .with_filter(filter_layer),
            )
            .with(::sentry::integrations::tracing::layer())
            .init();
    } else {
        tracing_subscriber::registry()
            .with(ErrorLayer::default())
            .with(json::layer().with_filter(filter_layer))
            .with(::sentry::integrations::tracing::layer())
            .init();
    }

    Ok(())
}

pub(super) fn install_crude() -> tracing::dispatcher::DefaultGuard {
    let filter_layer = EnvFilter::builder()
        .parse("trace,console_subscriber=info,runtime=info,tokio=info,tungstenite=info")
        .expect("infallible");
    let subscriber = tracing_subscriber::registry()
        .with(ErrorLayer::default())
        .with(filter_layer)
        .with(json::layer());
    tracing::dispatcher::set_default(&subscriber.into())
}

mod json {
    use std::collections::HashMap;

    use tracing::Subscriber;
    use tracing_subscriber::{layer::Layer, registry::LookupSpan};

    pub(super) fn layer<S>() -> impl Layer<S>
    where S: Subscriber + for<'lookup> LookupSpan<'lookup> {
        let mut json_layer = json_subscriber::fmt::layer()
            .with_file(true)
            .with_line_number(true)
            .flatten_event(true)
            .flatten_current_span_on_top_level(true);

        let inner_layer = json_layer.inner_layer_mut();
        inner_layer.with_thread_ids("thread_id");
        inner_layer.with_thread_names("thread_name");
        inner_layer.add_dynamic_field("pid", |_, _| {
            Some(serde_json::Value::Number(serde_json::Number::from(
                std::process::id(),
            )))
        });
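        // Flatten event fields to the top level; per the rename map below, the
        // tracing `message` field is emitted under the key `event`.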
        inner_layer.with_flattened_event_with_renames(
            move |name, map| match map.get(name) {
                Some(name) => name.as_str(),
                None => name,
            },
            HashMap::from([("message".to_owned(), "event".to_owned())]),
        );

        json_layer
    }
}

pub(crate) mod sentry {
    use std::str::FromStr as _;

    use tracing::trace;

    use crate::{VERSION, authentik_user_agent, config};

    pub(crate) fn install() -> sentry::ClientInitGuard {
        trace!("setting up sentry");
        let config = config::get();
        sentry::init(sentry::ClientOptions {
            dsn: config.error_reporting.sentry_dsn.clone().map(|dsn| {
                sentry::types::Dsn::from_str(&dsn).expect("Failed to create sentry DSN")
            }),
            release: Some(format!("authentik@{VERSION}").into()),
            environment: Some(config.error_reporting.environment.clone().into()),
            attach_stacktrace: true,
            send_default_pii: config.error_reporting.send_pii,
            sample_rate: config.error_reporting.sample_rate,
            traces_sample_rate: if config.debug {
                1.0
            } else {
                config.error_reporting.sample_rate
            },
            user_agent: authentik_user_agent().into(),
            ..sentry::ClientOptions::default()
        })
    }
}

src/worker.rs (new file, 366 lines)
@@ -0,0 +1,366 @@
use std::{
    env::temp_dir,
    os::unix,
    path::PathBuf,
    process::Stdio,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    time::Duration,
};

use argh::FromArgs;
use axum::body::Body;
use eyre::{Result, eyre};
use hyper_unix_socket::UnixSocketConnector;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use nix::{
    sys::signal::{Signal, kill},
    unistd::Pid,
};
use tokio::{
    net::UnixStream,
    process::{Child, Command},
    signal::unix::SignalKind,
    sync::{Mutex, broadcast::error::RecvError},
};
use tracing::{info, trace, warn};

use crate::{
    arbiter::{Arbiter, Tasks},
    axum::server,
    config,
    mode::Mode,
};

#[derive(Debug, Default, FromArgs, PartialEq)]
/// Run the authentik worker.
#[argh(subcommand, name = "worker")]
#[expect(
    clippy::empty_structs_with_brackets,
    reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}

const INITIAL_WORKER_ID: usize = 1000;
static INITIAL_WORKER_READY: AtomicBool = AtomicBool::new(false);

struct Worker(Child);

impl Worker {
    fn new(worker_id: usize, socket_path: Option<&str>) -> Result<Self> {
        info!(worker_id, "Starting worker");
        let mut cmd = Command::new("python");
        cmd.args(["-m", "lifecycle.worker_process", &worker_id.to_string()]);
        if let Some(socket_path) = socket_path {
            cmd.arg(socket_path);
        }
        Ok(Self(
            cmd.kill_on_drop(true)
                .stdout(Stdio::inherit())
                .stderr(Stdio::inherit())
                .spawn()?,
        ))
    }

    async fn shutdown(&mut self, signal: Signal) -> Result<()> {
        trace!(
            signal = signal.as_str(),
            "sending shutdown signal to worker"
        );
        if let Some(id) = self.0.id() {
            kill(Pid::from_raw(id.cast_signed()), signal)?;
        }
        self.0.wait().await?;
        Ok(())
    }

    async fn graceful_shutdown(&mut self) -> Result<()> {
        info!("gracefully shutting down worker");
        self.shutdown(Signal::SIGTERM).await
    }

    async fn fast_shutdown(&mut self) -> Result<()> {
        info!("immediately shutting down worker");
        self.shutdown(Signal::SIGINT).await
    }

    fn is_alive(&mut self) -> bool {
        let try_wait = self.0.try_wait();
        match try_wait {
            Ok(Some(code)) => {
                warn!("worker has exited with status {code}");
                false
            }
            Ok(None) => true,
            Err(err) => {
                warn!("failed to check the status of worker process, ignoring: {err}");
                true
            }
        }
    }
}

pub(crate) struct Workers {
    workers: Mutex<Vec<Worker>>,
    socket_path: PathBuf,
    pub(crate) client: Client<UnixSocketConnector<PathBuf>, Body>,
}

impl Workers {
    fn new(socket_path: PathBuf) -> Result<Self> {
        let mut workers = Vec::with_capacity(config::get().worker.processes.get());
        workers.push(Worker::new(
            INITIAL_WORKER_ID,
            Some(&format!("{}", &socket_path.display())),
        )?);

        let client = Client::builder(TokioExecutor::new())
            .pool_idle_timeout(Duration::from_secs(60))
            .set_host(false)
            .build(UnixSocketConnector::new(socket_path.clone()));

        Ok(Self {
            workers: Mutex::new(workers),
            socket_path,
            client,
        })
    }

    async fn start_other_workers(&self) -> Result<()> {
        for i in 1..config::get().worker.processes.get() {
            self.workers
                .lock()
                .await
                .push(Worker::new(INITIAL_WORKER_ID + i, None)?);
        }
        Ok(())
    }

    async fn graceful_shutdown(&self) -> Result<()> {
        let mut results = Vec::with_capacity(self.workers.lock().await.capacity());
        for worker in self.workers.lock().await.iter_mut() {
            results.push(worker.graceful_shutdown().await);
        }

        results.into_iter().find(Result::is_err).unwrap_or(Ok(()))
    }

    async fn fast_shutdown(&self) -> Result<()> {
        let mut results = Vec::with_capacity(self.workers.lock().await.capacity());
        for worker in self.workers.lock().await.iter_mut() {
            results.push(worker.fast_shutdown().await);
        }

        results.into_iter().find(Result::is_err).unwrap_or(Ok(()))
    }

    pub(crate) async fn are_alive(&self) -> bool {
        for worker in self.workers.lock().await.iter_mut() {
            if !worker.is_alive() {
                return false;
            }
        }
        true
    }

    async fn is_socket_ready(&self) -> bool {
        let result = UnixStream::connect(&self.socket_path).await;
        trace!("checking if worker socket is ready: {result:?}");
        result.is_ok()
    }
}

async fn watch_workers(arbiter: Arbiter, workers: Arc<Workers>) -> Result<()> {
    info!("starting worker watcher");
    let mut signals_rx = arbiter.signals_subscribe();
    loop {
        tokio::select! {
            signal = signals_rx.recv() => {
                match signal {
                    Ok(signal) => {
                        if signal == SignalKind::user_defined2() {
                            info!("worker notified us ready, marked ready for operation");
                            INITIAL_WORKER_READY.store(true, Ordering::Relaxed);
                            workers.start_other_workers().await?;
                        }
                    },
                    Err(RecvError::Lagged(_)) => {},
                    Err(RecvError::Closed) => {
                        warn!("error receiving signals");
                        return Err(RecvError::Closed.into());
                    }
                }
            },
            () = tokio::time::sleep(Duration::from_secs(1)), if !INITIAL_WORKER_READY.load(Ordering::Relaxed) => {
                // On some platforms the SIGUSR2 notification can be missed.
                // Fall back to probing the worker unix socket and mark ready once it accepts connections.
                if workers.is_socket_ready().await {
                    info!("worker socket is accepting connections, marked ready for operation");
                    INITIAL_WORKER_READY.store(true, Ordering::Relaxed);
                    workers.start_other_workers().await?;
                }
            },
            () = tokio::time::sleep(Duration::from_secs(5)) => {
                if !workers.are_alive().await {
                    return Err(eyre!("gunicorn has exited unexpectedly"));
                }
            },
            () = arbiter.fast_shutdown() => {
                workers.fast_shutdown().await?;
                return Ok(());
            },
            () = arbiter.graceful_shutdown() => {
                workers.graceful_shutdown().await?;
                return Ok(());
            },
        }
    }
}

mod healthcheck {
    use std::sync::Arc;

    use axum::{
        Router,
        body::Body,
        extract::{Request, State},
        http::{StatusCode, header::HOST},
        response::IntoResponse,
        routing::any,
    };

    use crate::{axum::router::wrap_router, db, worker::Workers};

    async fn health_ready(State(workers): State<Arc<Workers>>) -> impl IntoResponse {
        if !workers.are_alive().await || sqlx::query("SELECT 1").execute(db::get()).await.is_err() {
            StatusCode::SERVICE_UNAVAILABLE
        } else {
            let req = Request::builder()
                .method("GET")
                .uri("http://localhost:8000/-/health/ready/")
                .header(HOST, "localhost")
                .body(Body::from(""));
            if let Ok(req) = req
                && let Ok(res) = workers.client.request(req).await
            {
                res.status()
            } else {
                StatusCode::SERVICE_UNAVAILABLE
            }
        }
    }

    async fn health_live(State(workers): State<Arc<Workers>>) -> impl IntoResponse {
        let req = Request::builder()
            .method("GET")
            .uri("http://localhost:8000/-/health/live/")
            .header(HOST, "localhost")
            .body(Body::from(""));
        if let Ok(req) = req
            && let Ok(res) = workers.client.request(req).await
        {
            res.status()
        } else {
            StatusCode::SERVICE_UNAVAILABLE
        }
    }

    async fn fallback() -> impl IntoResponse {
        StatusCode::OK
    }

    pub(super) fn build_router(workers: Arc<Workers>) -> Router {
        wrap_router(
            Router::new()
.route("/-/heath/ready/", any(health_ready))
|
||||
.route("/-/heath/live/", any(health_live))
                .fallback(fallback)
                .with_state(workers),
            true,
        )
    }
}

mod worker_status {
    use std::time::Duration;

    use eyre::Result;
    use nix::unistd::gethostname;
    use uuid::Uuid;

    use crate::{arbiter::Arbiter, authentik_full_version, db};

    async fn keep(arbiter: Arbiter, id: Uuid, hostname: &str, version: &str) -> Result<()> {
        loop {
            tokio::select! {
                () = tokio::time::sleep(Duration::from_secs(30)) => {
                    sqlx::query("
                        INSERT INTO authentik_tasks_workerstatus (id, hostname, version, last_seen)
                        VALUES ($1, $2, $3, NOW())
                        ON CONFLICT (id) DO UPDATE SET last_seen = NOW()
                    ")
                    .bind(id)
                    .bind(hostname)
                    .bind(version)
                    .execute(db::get())
                    .await?;
                },
                () = arbiter.shutdown() => return Ok(()),
            }
        }
    }

    pub(super) async fn run(arbiter: Arbiter) -> Result<()> {
        let id = Uuid::new_v4();
        let raw_hostname = gethostname()?;
        let hostname = raw_hostname.to_string_lossy();
        let version = authentik_full_version();

        loop {
            tokio::select! {
                _ = keep(arbiter.clone(), id, hostname.as_ref(), &version) => {
                    tokio::select! {
                        () = tokio::time::sleep(Duration::from_secs(10)) => {},
                        () = arbiter.shutdown() => return Ok(()),
                    }
                },
                () = arbiter.shutdown() => return Ok(()),
            }
        }
    }
}

pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Workers>> {
    let arbiter = tasks.arbiter();

    let workers = Arc::new(Workers::new(temp_dir().join("authentik-worker.sock"))?);

    tasks
        .build_task()
        .name(&format!("{}::watch_workers", module_path!()))
        .spawn(watch_workers(arbiter.clone(), Arc::clone(&workers)))?;

    tasks
        .build_task()
        .name(&format!("{}::worker_status::run", module_path!()))
        .spawn(worker_status::run(arbiter))?;

    if Mode::get() == Mode::Worker {
        let router = healthcheck::build_router(Arc::clone(&workers));

        for addr in config::get().listen.http.iter().copied() {
            server::start_plain(tasks, "worker", router.clone(), addr)?;
        }

        server::start_unix(
            tasks,
            "worker",
            router,
            unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
        )?;
    }

    Ok(workers)
}

@@ -6,17 +6,58 @@ title: Reverse proxy
Since authentik uses WebSockets to communicate with Outposts, it does not support HTTP/1.0 reverse proxies. The HTTP/1.0 specification does not officially support WebSockets or protocol upgrades, though some clients may allow it.
:::

If you want to access authentik behind a reverse proxy, there are a few headers that must be passed upstream for authentik to be able to correctly identify a connection.

It is also recommended to use a [modern TLS configuration](https://ssl-config.mozilla.org/) and disable SSL/TLS protocols older than TLS 1.3.

If your reverse proxy isn't accessing authentik from a private IP address, [trusted proxy CIDRs configuration](./configuration/configuration.mdx#listen-settings) needs to be set on the authentik server to allow client IP address detection.

### Scheme

authentik and Proxy Providers need to know if they are being served over an HTTPS connection.

The connection scheme (HTTP/HTTPS) is determined as follows. If the incoming connection is from a trusted proxy, the following is considered:

- `X-Forwarded-Proto` header,
- `X-Forwarded-Scheme` header,
- `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239). If multiple `proto=` stanzas are present, only the first one is retained,
- whether the connection to the proxy was made over TLS, as reported by the PROXY protocol, if used.

If the connection is not trusted, or the above is missing, authentik will look at whether the connection was made over plaintext or TLS.

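For example, a trusted proxy announcing that its client connected over HTTPS could send the RFC 7239 header as:

```
Forwarded: proto=https;for=203.0.113.43
```
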
### Host

Required for various security checks, WebSocket handshake, and Outpost and Proxy Provider communication.

The Host is determined as follows. If the incoming connection is from a trusted proxy, the following is considered:

- `X-Forwarded-Host` header,
- `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239). If multiple `host=` stanzas are present, only the first one is retained.

If the connection is not trusted, or the above is missing, authentik will consider the following:

- `Host` header,
- host part of the request URL.

### Client IP

authentik needs to know the IP addresses of clients for various security features and for audit purposes.

The client IP is determined as follows. If the incoming connection is from a trusted proxy, the following is considered:

- the rightmost IP in the `X-Forwarded-For` header,
- `X-Real-IP` header,
- the rightmost IP in the `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239),
- the IP passed via the PROXY protocol, if used.

If the connection is not trusted, the client IP is taken from the TCP connection itself.

### WebSockets

The `Connection: Upgrade` and `Upgrade: WebSocket` headers are required to upgrade protocols for requests to the WebSocket endpoints under HTTP/1.1.

### Example configuration

The following nginx configuration sketch, based on the headers described above, can be used as a starting point for your own configuration.

```
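# Adjust server_name, certificates, and the upstream address for your deployment.
server {
    listen 443 ssl;
    server_name auth.example.com;

    # TLS configuration omitted; see the Mozilla generator linked above.

    location / {
        proxy_pass https://<authentik server>;
        proxy_http_version 1.1;

        # Scheme and client IP
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        # Host
        proxy_set_header Host $host;
        # WebSocket support
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
    }
}
```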