Compare commits

...

178 Commits

Author SHA1 Message Date
Marc 'risson' Schmitt
04c066d8b0 lint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 19:20:34 +01:00
Marc 'risson' Schmitt
f3341a4b83 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 18:46:47 +01:00
Marc 'risson' Schmitt
27f652dcf3 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 16:22:45 +01:00
Marc 'risson' Schmitt
dca2c2f536 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 16:10:43 +01:00
Marc 'risson' Schmitt
5d426411dd wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:59:31 +01:00
Marc 'risson' Schmitt
35ec2ea930 remove useless change
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:58:09 +01:00
Marc 'risson' Schmitt
b7c4d04c16 Merge remote-tracking branch 'origin/multiple-listeners' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:56:52 +01:00
Marc 'risson' Schmitt
8ef1b945e8 lint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:47:44 +01:00
Marc 'risson' Schmitt
7fab5b6e93 Merge branch 'main' into multiple-listeners 2026-03-19 14:23:47 +00:00
Marc 'risson' Schmitt
7468a7271c Update website/docs/releases/2026/v2026.5.md
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:23:30 +01:00
Marc 'risson' Schmitt
1a270f9c6e Apply suggestions from code review
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:23:01 +01:00
Marc 'risson' Schmitt
3ae126cd99 Update website/docs/install-config/configuration/configuration.mdx
Co-authored-by: Tana M Berry <tanamarieberry@yahoo.com>
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 15:22:42 +01:00
Marc 'risson' Schmitt
6db2fbc8aa wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 14:38:19 +01:00
Marc 'risson' Schmitt
32f6738a40 errgroup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 14:33:42 +01:00
Marc 'risson' Schmitt
1ddc596362 better unix listener
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-19 14:22:41 +01:00
Marc 'risson' Schmitt
1281371077 Merge branch 'main' into rust-server 2026-03-18 15:29:43 +01:00
Marc 'risson' Schmitt
58508ebc4e Merge branch 'main' into rust-server 2026-03-18 15:29:25 +01:00
Marc 'risson' Schmitt
aa614ad31c lint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-18 15:22:47 +01:00
Marc 'risson' Schmitt
b9b1c7ccf6 metrics socket and healthchecks for all outposts
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-18 15:17:33 +01:00
Marc 'risson' Schmitt
f8209680fa server healthcheck and unix socket
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-18 14:58:55 +01:00
Marc 'risson' Schmitt
2b2c6a3b9b Merge branch 'main' into multiple-listeners 2026-03-18 13:43:09 +01:00
Marc 'risson' Schmitt
62644a79fd root: allow listening on multiple IPs
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-16 18:40:45 +01:00
Marc 'risson' Schmitt
c426c94a25 Merge branch 'main' into rust-server 2026-03-16 14:09:37 +01:00
Marc 'risson' Schmitt
2e04738306 use waitgroups for multiple servers, TODO: fix healthchecks
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-13 16:13:56 +01:00
Marc 'risson' Schmitt
297e8db6eb use waitgroups for multiple servers
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-13 16:03:45 +01:00
Marc 'risson' Schmitt
5b9a30be4b Merge branch 'main' into rust-server 2026-03-13 15:59:30 +01:00
Marc 'risson' Schmitt
457429f261 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-12 14:45:48 +01:00
Marc 'risson' Schmitt
a0bac73c59 go ak api controller: add support for unix urls
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-12 14:44:15 +01:00
Marc 'risson' Schmitt
b82abaf230 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-12 13:45:21 +01:00
Marc 'risson' Schmitt
c4b1e4bd44 http timeouts
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-12 13:41:50 +01:00
Marc 'risson' Schmitt
5592c4769a Merge branch 'main' into rust-server 2026-03-12 13:26:15 +01:00
Marc 'risson' Schmitt
f71f5b7278 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 19:12:03 +01:00
Marc 'risson' Schmitt
d7159cfce2 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 19:02:21 +01:00
Marc 'risson' Schmitt
30dc4e120b fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 19:01:15 +01:00
Marc 'risson' Schmitt
619023be75 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 19:00:17 +01:00
Marc 'risson' Schmitt
de63473cd2 more ci
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 18:59:23 +01:00
Marc 'risson' Schmitt
6aa50b962c rustfmt in ci
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 18:50:01 +01:00
Marc 'risson' Schmitt
f240ca1708 more ci
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 18:48:21 +01:00
Marc 'risson' Schmitt
550da2005e start ci
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 18:46:28 +01:00
Marc 'risson' Schmitt
8818a0b06c revert
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 17:54:18 +01:00
Marc 'risson' Schmitt
013190ddd0 fix tests?
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 17:50:43 +01:00
Marc 'risson' Schmitt
6fb777ae5b make server listen on unix socket
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 17:45:55 +01:00
Marc 'risson' Schmitt
41f13d8805 fix sentry tracing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 16:26:12 +01:00
Marc 'risson' Schmitt
fc5f0e7dc5 spellcheck
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 15:20:51 +01:00
Marc 'risson' Schmitt
9b9379ac8f lint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 15:09:28 +01:00
Marc 'risson' Schmitt
c4b0825dad tasks/test: remove worker health check
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 15:06:39 +01:00
Marc 'risson' Schmitt
946ace14c1 fix makefile
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-11 15:05:50 +01:00
Marc 'risson' Schmitt
6a9eb8e9c7 Merge branch 'main' into rust-server 2026-03-11 14:28:31 +01:00
Marc 'risson' Schmitt
4f0d0e72d5 disable sentry span handling, it's broken
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 18:55:59 +01:00
Marc 'risson' Schmitt
411648672e config: separate initial loading and starting the reloader
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 18:47:51 +01:00
Marc 'risson' Schmitt
d5f6d30aeb update deps
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 17:26:24 +01:00
Marc 'risson' Schmitt
1508ad0ab8 subpath
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 17:10:56 +01:00
Marc 'risson' Schmitt
892e8fd856 config fallback values
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 17:04:58 +01:00
Marc 'risson' Schmitt
d4b0ac7c14 more todos
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 16:14:19 +01:00
Marc 'risson' Schmitt
fe4857abbb db: use connection callbacks for search path and application name
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 16:13:02 +01:00
Marc 'risson' Schmitt
8b73872c0d refine sentry setup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 16:00:19 +01:00
Marc 'risson' Schmitt
d22597377a finally finish tracing
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 15:40:34 +01:00
Marc 'risson' Schmitt
58d198d60a Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-10 13:31:55 +01:00
Marc 'risson' Schmitt
1de19546d7 tracing almost done
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-09 18:00:38 +01:00
Marc 'risson' Schmitt
8ad054ce65 Merge branch 'main' into rust-server 2026-03-09 16:22:12 +01:00
Marc 'risson' Schmitt
df95fc89eb wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-06 19:53:19 +01:00
Marc 'risson' Schmitt
75898710f1 update, some logging
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-06 15:30:46 +01:00
Marc 'risson' Schmitt
3a5a0c2e4f fmt
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-06 13:45:37 +01:00
Marc 'risson' Schmitt
b806e14a00 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-06 13:44:28 +01:00
Marc 'risson' Schmitt
c2d02cd807 fix deny
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 17:47:54 +01:00
Marc 'risson' Schmitt
1212402231 remove deprecated rustls-pemfile
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 17:43:43 +01:00
Marc 'risson' Schmitt
2927f414c5 extract tracelayer
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 17:33:37 +01:00
Marc 'risson' Schmitt
5ba18fbd55 fmt
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 13:46:35 +01:00
Marc 'risson' Schmitt
1b108e40d6 Merge branch 'main' into rust-server 2026-03-05 13:46:07 +01:00
Marc 'risson' Schmitt
982ae7b261 nit
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 13:35:24 +01:00
Marc 'risson' Schmitt
294a656ad2 worker status
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-05 13:34:35 +01:00
Marc 'risson' Schmitt
dab8bab916 better handling of socket path for future testing, healthcheck
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-04 16:54:57 +01:00
Marc 'risson' Schmitt
ee1803a0ae pedantic
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-03 18:33:13 +01:00
Marc 'risson' Schmitt
99c9894a04 extract brands
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-03 16:48:25 +01:00
Marc 'risson' Schmitt
2352ce72c9 Merge branch 'main' into rust-server 2026-03-03 14:03:27 +01:00
Marc 'risson' Schmitt
bb28e6425d small fixes
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-02 18:08:57 +01:00
Marc 'risson' Schmitt
f2149dfd90 write mode to authentik-mode file
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-02 15:40:30 +01:00
Marc 'risson' Schmitt
2ff0f09db1 spawn_blocking for tls computations
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-02 15:24:46 +01:00
Marc 'risson' Schmitt
40a91fd4fb fix minor stuff
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-02 15:08:55 +01:00
Marc 'risson' Schmitt
2e3f76441c Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-03-02 15:05:46 +01:00
Marc 'risson' Schmitt
f91474dd91 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-27 18:19:54 +01:00
Marc 'risson' Schmitt
61dbd5976f some todos
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-27 18:16:25 +01:00
Marc 'risson' Schmitt
8099ac6508 some cleanup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-27 18:15:50 +01:00
Marc 'risson' Schmitt
61ed26e3f6 worker started from rust
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-27 16:54:14 +01:00
Marc 'risson' Schmitt
ea17d4cbf1 finish tls
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-27 13:58:38 +01:00
Marc 'risson' Schmitt
ac388667d0 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-26 17:14:08 +01:00
Marc 'risson' Schmitt
cdc42de5b5 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-24 18:50:08 +01:00
Marc 'risson' Schmitt
2770c3a7e0 Merge branch 'main' into rust-server 2026-02-24 15:20:32 +01:00
Marc 'risson' Schmitt
f41f501702 finish metrics
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-24 15:20:09 +01:00
Marc 'risson' Schmitt
08685a574a better metrics
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-23 18:40:32 +01:00
Marc 'risson' Schmitt
15377f5154 better arbiter again
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-23 15:30:33 +01:00
Marc 'risson' Schmitt
52da505aab Merge branch 'main' into rust-server 2026-02-23 13:47:34 +01:00
Marc 'risson' Schmitt
d8a2a069aa add tests for extractors
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-20 18:38:09 +01:00
Marc 'risson' Schmitt
fec9dcc2e7 better logging
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-20 16:54:32 +01:00
Marc 'risson' Schmitt
b644fa5a2c revert unintended change
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-20 13:30:00 +01:00
Marc 'risson' Schmitt
9a5d59533e remove rust worker, fixup main
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-20 13:28:02 +01:00
Marc 'risson' Schmitt
3c64570398 use arcswap instead of lock for config
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-20 13:19:54 +01:00
Marc 'risson' Schmitt
a735f6dcf3 support proxying websockets to core
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 18:44:16 +01:00
Marc 'risson' Schmitt
f33e7f13eb autoreload
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 17:47:21 +01:00
Marc 'risson' Schmitt
eee00fa29b fix returned headers
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 16:45:42 +01:00
Marc 'risson' Schmitt
5a95a14a8f wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 16:38:21 +01:00
Marc 'risson' Schmitt
7b46fac608 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 15:31:35 +01:00
Marc 'risson' Schmitt
bb488e1c2c some better lints
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 15:29:02 +01:00
Marc 'risson' Schmitt
138aa0e4e9 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-19 14:33:17 +01:00
Marc 'risson' Schmitt
e65cd2999f tls headers
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 19:36:53 +01:00
Marc 'risson' Schmitt
490790c272 cleanup forwarding
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 19:03:30 +01:00
Marc 'risson' Schmitt
b640b42dbb wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 18:16:42 +01:00
Marc 'risson' Schmitt
1371465ebe fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 18:13:48 +01:00
Marc 'risson' Schmitt
c623b96dc2 some more cleanup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 18:13:18 +01:00
Marc 'risson' Schmitt
43fe1918db cleanup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 17:38:59 +01:00
Marc 'risson' Schmitt
3e2489834d remove cargo.lock for docsmg
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 15:56:28 +01:00
Marc 'risson' Schmitt
7ba86b7de3 introduce the arbiter, treewide fmt
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 15:55:45 +01:00
Marc 'risson' Schmitt
85ef3cda04 bring docsmg up to standards
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 15:18:48 +01:00
Marc 'risson' Schmitt
62911536bf fix tests
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 14:14:29 +01:00
Marc 'risson' Schmitt
1a27971399 move everything to a single crate
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-18 14:13:51 +01:00
Marc 'risson' Schmitt
7a0e946bb5 start moving to a single crate
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-17 18:45:49 +01:00
Marc 'risson' Schmitt
428ccc2c14 fix metrics endpoint
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-17 18:27:17 +01:00
Marc 'risson' Schmitt
0b706d5830 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-17 13:44:03 +01:00
Marc 'risson' Schmitt
b9f4a1aed7 Merge branch 'main' into rust-server 2026-02-16 14:07:56 +01:00
Marc 'risson' Schmitt
d2cb45aadf start cleanup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-11 15:15:43 +01:00
Marc 'risson' Schmitt
de12748f25 Merge branch 'main' into rust-server 2026-02-11 14:41:18 +01:00
Marc 'risson' Schmitt
f8f39b8edc start on metrics
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-09 19:25:06 +01:00
Marc 'risson' Schmitt
986385a951 initialize python globally when needed
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-09 18:45:20 +01:00
Marc 'risson' Schmitt
129ed95cf0 Merge branch 'main' into rust-server 2026-02-09 18:39:57 +01:00
Marc 'risson' Schmitt
dc0d535fcc small improvements to storage token checks
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-09 18:39:40 +01:00
Marc 'risson' Schmitt
5c0e23a78f features for which process to build
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-09 18:07:39 +01:00
Marc 'risson' Schmitt
b4bf082864 Merge branch 'main' into rust-server 2026-02-09 14:37:11 +01:00
Dominic R
2f00983c29 static file handling 2026-02-08 11:41:30 -05:00
Dominic R
af93a1e230 rev tls_state to what it was before 2026-02-08 09:47:32 -05:00
Dominic R
dbb3898621 fix client error 2026-02-08 09:44:47 -05:00
Dominic R
a668ddcaf5 make it work on macos 2026-02-08 09:32:38 -05:00
Marc 'risson' Schmitt
051aea6f99 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-06 14:47:51 +01:00
Marc 'risson' Schmitt
b8104ec156 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-05 17:30:12 +01:00
Marc 'risson' Schmitt
e59970e6ab fmt
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-04 19:04:18 +01:00
Marc 'risson' Schmitt
0b50b0aa13 actually use the proxy protocol acceptor
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-04 19:03:45 +01:00
Marc 'risson' Schmitt
7b9b1c2c70 proxy protocol and tls extractors finally
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-04 18:55:41 +01:00
Marc 'risson' Schmitt
1e1cdffb33 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-04 14:29:20 +01:00
Marc 'risson' Schmitt
8ad572ba35 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-02 19:14:14 +01:00
Marc 'risson' Schmitt
8a5b8ad047 custom tls acceptor, wip proxy protocol acceptor
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-02 18:43:27 +01:00
Marc 'risson' Schmitt
907a4ce478 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-02-02 17:07:24 +01:00
Marc 'risson' Schmitt
a26254df02 compression and loading page
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-30 17:58:59 +01:00
Marc 'risson' Schmitt
bf9679dcb5 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-30 17:01:10 +01:00
Marc 'risson' Schmitt
71ee2f6c66 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-30 17:00:49 +01:00
Marc 'risson' Schmitt
90fb12a804 proxying works correctly now
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-30 16:09:41 +01:00
Marc 'risson' Schmitt
e271a8a0af static files
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-29 15:43:36 +01:00
Marc 'risson' Schmitt
6100fd7800 remove sea-orm
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-29 13:59:11 +01:00
Marc 'risson' Schmitt
b78d62f550 add console-subscriber for tokio-console debugging
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 15:56:14 +01:00
Marc 'risson' Schmitt
21eb1bb7d0 fix gunicorn shutdown detection
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 15:56:01 +01:00
Marc 'risson' Schmitt
e4445a44c4 db
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 15:22:41 +01:00
Marc 'risson' Schmitt
6fecbb41ca start on db
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 01:53:37 +01:00
Marc 'risson' Schmitt
4a840796bf Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 01:36:55 +01:00
Marc 'risson' Schmitt
cc7f190735 config reloading
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-28 01:33:25 +01:00
Marc 'risson' Schmitt
c4962f86dd wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-27 20:17:32 +01:00
Marc 'risson' Schmitt
ad672338e0 Merge branch 'main' into rust-server 2026-01-27 15:13:32 +01:00
Marc 'risson' Schmitt
fadf344955 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-26 14:25:40 +01:00
Marc 'risson' Schmitt
8c58873a3a wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-22 17:19:21 +01:00
Marc 'risson' Schmitt
ac7dd69be2 fixup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-21 17:57:24 +01:00
Marc 'risson' Schmitt
f01ab7ccb2 we proxying
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-21 17:53:56 +01:00
Marc 'risson' Schmitt
13f7ac6eca wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-20 17:57:09 +01:00
Marc 'risson' Schmitt
24202f9a3f wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-20 15:26:56 +01:00
Marc 'risson' Schmitt
5a72130576 cleanup
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-19 17:52:35 +01:00
Marc 'risson' Schmitt
fe5d24004e Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-19 16:30:12 +01:00
Marc 'risson' Schmitt
dd7c13c5bd wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-12 18:12:52 +01:00
Marc 'risson' Schmitt
32de1ab6c6 Merge branch 'main' into rust-server
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2026-01-12 14:07:32 +01:00
Marc 'risson' Schmitt
6e4384d672 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-17 14:26:40 +01:00
Marc 'risson' Schmitt
79f7759d4b Merge branch 'main' into rust-server 2025-11-14 14:48:26 +01:00
Marc 'risson' Schmitt
0ca41cb184 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-13 19:01:40 +01:00
Marc 'risson' Schmitt
f8e5c895d6 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-13 18:58:30 +01:00
Marc 'risson' Schmitt
2ba8991a3b wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-13 18:30:57 +01:00
Marc 'risson' Schmitt
19b36d2e0d wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-13 14:08:54 +01:00
Marc 'risson' Schmitt
fb802a53bc Merge branch 'main' into rust-server 2025-11-13 03:28:17 +01:00
Marc 'risson' Schmitt
2f6465d5a0 Merge branch 'main' into rust-server 2025-11-12 15:16:35 +01:00
Marc 'risson' Schmitt
c5437d2b0b wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-07 18:46:21 +01:00
Marc 'risson' Schmitt
8e2e90a87f wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-06 19:14:07 +01:00
Marc 'risson' Schmitt
4deb3d45cf wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-06 18:43:33 +01:00
Marc 'risson' Schmitt
b61bb3cc17 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-05 16:07:11 +01:00
Marc 'risson' Schmitt
af3332df9f wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-04 19:30:43 +01:00
Marc 'risson' Schmitt
0849df7478 wip
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
2025-11-04 17:53:54 +01:00
63 changed files with 10810 additions and 300 deletions

2
.cargo/config.toml Normal file
View File

@@ -0,0 +1,2 @@
[build]
rustflags = ["--cfg", "tokio_unstable"]

View File

@@ -1,5 +1,17 @@
[licenses]
allow = ["Apache-2.0", "MIT", "MPL-2.0", "Unicode-3.0"]
allow = [
"Apache-2.0 WITH LLVM-exception",
"Apache-2.0",
"BSD-3-Clause",
"CC0-1.0",
"CDLA-Permissive-2.0",
"ISC",
"MIT",
"MPL-2.0",
"OpenSSL",
"Unicode-3.0",
"Zlib",
]
[licenses.private]
ignore = true

View File

@@ -144,6 +144,7 @@ jobs:
CI_TEST_SEED: ${{ needs.test-make-seed.outputs.seed }}
CI_RUN_ID: ${{ matrix.run_id }}
CI_TOTAL_RUNS: "5"
PROMETHEUS_MULTIPROC_DIR: /tmp
run: |
uv run make ci-test
- uses: ./.github/actions/test-results
@@ -173,6 +174,7 @@ jobs:
CI_TEST_SEED: ${{ needs.test-make-seed.outputs.seed }}
CI_RUN_ID: ${{ matrix.run_id }}
CI_TOTAL_RUNS: "5"
PROMETHEUS_MULTIPROC_DIR: /tmp
run: |
uv run make ci-test
- uses: ./.github/actions/test-results
@@ -189,6 +191,8 @@ jobs:
- name: Create k8s Kind Cluster
uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # v1.14.0
- name: run integration
env:
PROMETHEUS_MULTIPROC_DIR: /tmp
run: |
uv run coverage run manage.py test tests/integration
uv run coverage xml
@@ -245,6 +249,8 @@ jobs:
npm run build
npm run build:sfe
- name: run e2e
env:
PROMETHEUS_MULTIPROC_DIR: /tmp
run: |
uv run coverage run manage.py test ${{ matrix.job.glob }}
uv run coverage xml
@@ -288,6 +294,8 @@ jobs:
npm run build
npm run build:sfe
- name: run conformance
env:
PROMETHEUS_MULTIPROC_DIR: /tmp
run: |
uv run coverage run manage.py test ${{ matrix.job.glob }}
uv run coverage xml

4916
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
[workspace]
members = ["website/scripts/docsmg"]
members = [".", "website/scripts/docsmg"]
resolver = "3"
[workspace.package]
@@ -12,11 +12,101 @@ license-file = "LICENSE"
publish = false
[workspace.dependencies]
arc-swap = "1.8.2"
argh = "0.1.17"
async-trait = "0.1.89"
aws-lc-rs = { version = "1.16.1", features = ["fips"] }
axum = { version = "0.8.8", features = ["http2", "macros", "ws"] }
axum-server = { version = "0.8.0", features = ["tls-rustls-no-provider"] }
bytes = "1.11.1"
chrono = "0.4.44"
clap = { version = "4.5.59", features = ["derive", "env"] }
client-ip = { version = "0.2.1", features = ["forwarded-header"] }
color-eyre = "0.6.5"
colored = "3.1.1"
config = { version = "0.15.19", default-features = false, features = [
"yaml",
"async",
] }
console-subscriber = "0.5.0"
dotenvy = "0.15.7"
durstr = "0.4.0"
eyre = "0.6.12"
forwarded-header-value = "0.1.1"
futures = "0.3.32"
glob = "0.3.3"
http-body-util = "0.1.3"
hyper = "1.8.1"
hyper-unix-socket = "0.3.0"
hyper-util = "0.1.20"
ipnet = { version = "2.12.0", features = ["serde"] }
# See https://github.com/mladedav/json-subscriber/pull/23
json-subscriber = { git = "https://github.com/rissson/json-subscriber.git", rev = "950ad7cb887a0a14fd5cb8afb8e76db1f456c032" }
jsonwebtoken = { version = "10.3.0", default-features = false, features = [
"aws_lc_rs",
] }
metrics = "0.24.3"
metrics-exporter-prometheus = { version = "0.18.1", default-features = false }
nix = { version = "0.31.2", features = ["hostname", "signal"] }
notify = "8.2.0"
pem = "3.0.6"
pin-project-lite = "0.2.17"
pyo3 = "0.28.2"
percent-encoding = "2.3.2"
rcgen = { version = "0.14.7", default-features = false, features = [
"aws_lc_rs",
"fips",
] }
regex = "1.12.3"
rustls = { version = "0.23.37", features = ["fips"] }
sentry = { version = "0.47.0", default-features = false, features = [
"backtrace",
"contexts",
"debug-images",
"panic",
"rustls",
"reqwest",
"tower",
"tracing",
] }
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
sqlx = { version = "0.8.6", default-features = false, features = [
"runtime-tokio",
"tls-rustls-aws-lc-rs",
"postgres",
"derive",
"macros",
"uuid",
"chrono",
"ipnet",
"json",
] }
time = "0.3.47"
thiserror = "2.0.18"
tokio = { version = "1.50.0", features = ["full"] }
tokio-rustls = "0.26.4"
tokio-tungstenite = "0.28.0"
tokio-util = "0.7.18"
tower = "0.5.3"
tower-http = { version = "0.6.8", features = [
"compression-br",
"compression-deflate",
"compression-gzip",
"compression-zstd",
"fs",
"timeout",
] }
tower-service = "0.3.3"
tracing = "0.1.44"
tracing-error = "0.2.1"
tracing-subscriber = { version = "0.3.22", features = [
"env-filter",
"json",
"tracing-log",
] }
url = "2.5.8"
uuid = { version = "1.22.0", features = ["v4"] }
[profile.dev.package.backtrace]
opt-level = 3
@@ -60,10 +150,14 @@ perf = { priority = -1, level = "warn" }
style = { priority = -1, level = "warn" }
suspicious = { priority = -1, level = "warn" }
### and disable the ones we don't want
### cargo group
multiple_crate_versions = "allow"
### pedantic group
redundant_closure_for_method_calls = "allow"
struct_field_names = "allow"
too_many_lines = "allow"
### nursery
missing_const_for_fn = "allow"
redundant_pub_crate = "allow"
option_if_let_else = "allow"
### restriction group
@@ -78,7 +172,6 @@ create_dir = "warn"
dbg_macro = "warn"
default_numeric_fallback = "warn"
disallowed_script_idents = "warn"
doc_paragraphs_missing_punctuation = "warn"
empty_drop = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
@@ -131,3 +224,73 @@ unused_trait_names = "warn"
unwrap_in_result = "warn"
unwrap_used = "warn"
verbose_file_reads = "warn"
[package]
name = "authentik"
version = "2026.5.0-rc1"
authors.workspace = true
edition.workspace = true
readme.workspace = true
homepage.workspace = true
repository.workspace = true
license-file.workspace = true
publish.workspace = true
[features]
default = ["core", "proxy"]
proxy = []
core = ["proxy", "dep:sqlx", "dep:pyo3"]
[dependencies]
arc-swap.workspace = true
argh.workspace = true
async-trait.workspace = true
aws-lc-rs.workspace = true
axum-server.workspace = true
axum.workspace = true
client-ip.workspace = true
color-eyre.workspace = true
config.workspace = true
console-subscriber.workspace = true
durstr.workspace = true
eyre.workspace = true
forwarded-header-value.workspace = true
futures.workspace = true
glob.workspace = true
http-body-util.workspace = true
hyper-unix-socket.workspace = true
hyper-util.workspace = true
hyper.workspace = true
ipnet.workspace = true
json-subscriber.workspace = true
jsonwebtoken.workspace = true
metrics.workspace = true
metrics-exporter-prometheus.workspace = true
nix.workspace = true
notify.workspace = true
pem.workspace = true
percent-encoding.workspace = true
pin-project-lite.workspace = true
pyo3 = { workspace = true, optional = true }
rcgen.workspace = true
rustls.workspace = true
sentry.workspace = true
serde.workspace = true
serde_json.workspace = true
sqlx = { workspace = true, optional = true }
thiserror.workspace = true
time.workspace = true
tokio-rustls.workspace = true
tokio-tungstenite.workspace = true
tokio-util.workspace = true
tokio.workspace = true
tower-http.workspace = true
tower.workspace = true
tracing-error.workspace = true
tracing-subscriber.workspace = true
tracing.workspace = true
url.workspace = true
uuid.workspace = true
[lints]
workspace = true

View File

@@ -84,7 +84,7 @@ test: ## Run the server tests and produce a coverage report (locally)
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
$(UV) run black $(PY_SOURCES)
$(UV) run ruff check --fix $(PY_SOURCES)
$(CARGO) +nightly fmt --all -- --config-path .config/rustfmt.toml
$(CARGO) +nightly fmt --all -- --config-path .cargo/rustfmt.toml
lint-spellcheck: ## Reports spelling errors.
npm run lint:spellcheck
@@ -110,12 +110,24 @@ i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that requir
aws-cfn:
cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
run-server: ## Run the main authentik server process
run: ## Run the authentik server and worker, without auto reloading
$(UV) run ak allinone
run-watch: ## Run the authentik server and worker, with auto reloading
$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak allinone
run-server: ## Run the authentik server, without auto reloading
$(UV) run ak server
run-worker: ## Run the main authentik worker process
run-server-watch: ## Run the authentik server, with auto reloading
$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak server
run-worker: ## Run the authentik worker, without auto reloading
$(UV) run ak worker
run-worker-watch: ## Run the authentik worker, with auto reloading
$(UV) run watchexec --on-busy-update=restart --stop-signal=SIGINT --exts py,rs --no-meta --notify -- ak worker
core-i18n-extract:
$(UV) run ak makemessages \
--add-location file \
@@ -154,7 +166,7 @@ ifndef version
$(error Usage: make bump version=20xx.xx.xx )
endif
$(eval current_version := $(shell cat ${PWD}/internal/constants/VERSION))
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml
$(SED_INPLACE) 's/^version = ".*"/version = "$(version)"/' ${PWD}/pyproject.toml ${PWD}/Cargo.toml
$(SED_INPLACE) 's/^VERSION = ".*"/VERSION = "$(version)"/' ${PWD}/authentik/__init__.py
$(MAKE) gen-build gen-compose aws-cfn
$(SED_INPLACE) "s/\"${current_version}\"/\"$(version)\"/" ${PWD}/package.json ${PWD}/package-lock.json ${PWD}/web/package.json ${PWD}/web/package-lock.json
@@ -359,13 +371,13 @@ ci-lint-pending-migrations: ci--meta-debug
$(UV) run ak makemigrations --check
ci-lint-cargo-deny: ci--meta-debug
$(CARGO) deny --locked --workspace check --config .config/deny.toml
$(CARGO) deny --locked --workspace check --config .cargo/deny.toml
ci-lint-cargo-machete: ci--meta-debug
$(CARGO) machete
ci-lint-rustfmt: ci--meta-debug
$(CARGO) +nightly fmt --all --check -- --config-path .config/rustfmt.toml
$(CARGO) +nightly fmt --all --check -- --config-path .cargo/rustfmt.toml
ci-lint-clippy: ci--meta-debug
$(CARGO) clippy -- -D warnings

View File

@@ -92,6 +92,7 @@ class FileBackend(ManageableBackend):
"nbf": now() - timedelta(seconds=15),
},
key=sha256(f"{settings.SECRET_KEY}:{self.usage}".encode()).hexdigest(),
# Must match crates/authentik-server/src/static.rs
algorithm="HS256",
)
url = f"{prefix}/files/{path}?token={token}"

View File

@@ -1,7 +1,5 @@
"""Apply blueprint from commandline"""
from sys import exit as sys_exit
from django.core.management.base import BaseCommand, no_translations
from structlog.stdlib import get_logger
@@ -28,7 +26,7 @@ class Command(BaseCommand):
self.stderr.write("Blueprint invalid")
for log in logs:
self.stderr.write(f"\t{log.logger}: {log.event}: {log.attributes}")
sys_exit(1)
raise RuntimeError("Blueprint invalid")
importer.apply()
def add_arguments(self, parser):

View File

@@ -342,10 +342,10 @@ def django_db_config(config: ConfigLoader | None = None) -> dict:
"default": {
"ENGINE": "psqlextra.backend",
"HOST": config.get("postgresql.host"),
"NAME": config.get("postgresql.name"),
"PORT": config.get("postgresql.port"),
"USER": config.get("postgresql.user"),
"PASSWORD": config.get("postgresql.password"),
"PORT": config.get("postgresql.port"),
"NAME": config.get("postgresql.name"),
"OPTIONS": {
"sslmode": config.get("postgresql.sslmode"),
"sslrootcert": config.get("postgresql.sslrootcert"),
@@ -423,4 +423,5 @@ if __name__ == "__main__":
if len(argv) < 2: # noqa: PLR2004
print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder))
else:
print(CONFIG.get(argv[-1]))
for arg in argv[1:]:
print(CONFIG.get(arg))

View File

@@ -17,11 +17,13 @@
postgresql:
host: localhost
name: authentik
user: authentik
port: 5432
user: authentik
password: "env://POSTGRES_PASSWORD"
name: authentik
sslmode: disable
conn_max_age: 60
conn_health_checks: false
use_pool: False
test:
name: test_authentik
@@ -72,6 +74,19 @@ log_level: info
log:
http_headers:
- User-Agent
rust_log:
"console_subscriber": info
"h2": info
"hyper_util": warn
"mio": info
"notify": info
"reqwest": info
"runtime": info
"rustls": info
"sqlx": info
"sqlx_postgres": info
"tokio": info
"tungstenite": info
sessions:
unauthenticated_age: days=1
@@ -143,8 +158,7 @@ tenants:
blueprints_dir: /blueprints
web:
# No default here as it's set dynamically
# workers: 2
workers: 2
threads: 4
path: /
timeout_http_read_header: 5s

View File

@@ -41,7 +41,7 @@ def structlog_configure():
add_process_id,
add_tenant_information,
structlog.stdlib.PositionalArgumentsFormatter(),
structlog.processors.TimeStamper(fmt="iso", utc=False),
structlog.processors.TimeStamper(fmt="iso", utc=True),
structlog.processors.StackInfoRenderer(),
structlog.processors.ExceptionRenderer(
structlog.tracebacks.ExceptionDictTransformer(show_locals=CONFIG.get_bool("debug"))

View File

@@ -339,6 +339,9 @@ class LoggingMiddleware:
def log(self, request: HttpRequest, status_code: int, runtime: int, **kwargs):
"""Log request"""
# Those are logged by the server above
if request.path in ("/-/metrics/", "/-/health/ready/"):
return
for header in self.headers_to_log:
header_value = request.headers.get(header)
if not header_value:

View File

@@ -1,37 +1,21 @@
"""Metrics view"""
from hmac import compare_digest
from pathlib import Path
from tempfile import gettempdir
from django.conf import settings
from django.db import connections
from django.db.utils import OperationalError
from django.dispatch import Signal
from django.http import HttpRequest, HttpResponse
from django.views import View
from django_prometheus.exports import ExportToDjangoView
monitoring_set = Signal()
class MetricsView(View):
"""Wrapper around ExportToDjangoView with authentication, accessed by the authentik router"""
def __init__(self, **kwargs):
_tmp = Path(gettempdir())
with open(_tmp / "authentik-core-metrics.key") as _f:
self.monitoring_key = _f.read()
"""View for metrics monitoring_set signal, accessed by the authentik router"""
def get(self, request: HttpRequest) -> HttpResponse:
"""Check for HTTP-Basic auth"""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
auth_type, _, given_credentials = auth_header.partition(" ")
authed = auth_type == "Bearer" and compare_digest(given_credentials, self.monitoring_key)
if not authed and not settings.DEBUG:
return HttpResponse(status=401)
monitoring_set.send_robust(self)
return ExportToDjangoView(request)
return HttpResponse(status=204)
class LiveView(View):

View File

@@ -440,8 +440,6 @@ DRAMATIQ = {
("authentik.tasks.middleware.TaskLogMiddleware", {}),
("authentik.tasks.middleware.LoggingMiddleware", {}),
("authentik.tasks.middleware.DescriptionMiddleware", {}),
("authentik.tasks.middleware.WorkerHealthcheckMiddleware", {}),
("authentik.tasks.middleware.WorkerStatusMiddleware", {}),
(
"authentik.tasks.middleware.MetricsMiddleware",
{

View File

@@ -14,12 +14,12 @@ class TestRoot(TransactionTestCase):
def setUp(self):
_tmp = Path(gettempdir())
self.token = token_urlsafe(32)
with open(_tmp / "authentik-core-metrics.key", "w") as _f:
with open(_tmp / "authentik-metrics-gunicorn.key", "w") as _f:
_f.write(self.token)
def tearDown(self):
_tmp = Path(gettempdir())
(_tmp / "authentik-core-metrics.key").unlink()
(_tmp / "authentik-metrics-gunicorn.key").unlink()
def test_monitoring_error(self):
"""Test monitoring without any credentials"""

View File

@@ -1,5 +1,6 @@
import pglock
from django.utils.timezone import now, timedelta
from datetime import timedelta
from django.utils.timezone import now
from drf_spectacular.utils import extend_schema, inline_serializer
from packaging.version import parse
from rest_framework.fields import BooleanField, CharField
@@ -31,18 +32,13 @@ class WorkerView(APIView):
def get(self, request: Request) -> Response:
response = []
our_version = parse(authentik_full_version())
for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(minutes=2)):
lock_id = f"goauthentik.io/worker/status/{status.pk}"
with pglock.advisory(lock_id, timeout=0, side_effect=pglock.Return) as acquired:
# The worker doesn't hold the lock, it isn't running
if acquired:
continue
version_matching = parse(status.version) == our_version
response.append(
{
"worker_id": f"{status.pk}@{status.hostname}",
"version": status.version,
"version_matching": version_matching,
}
)
for status in WorkerStatus.objects.filter(last_seen__gt=now() - timedelta(seconds=45)):
version_matching = parse(status.version) == our_version
response.append(
{
"worker_id": f"{status.pk}@{status.hostname}",
"version": status.version,
"version_matching": version_matching,
}
)
return Response(response)

View File

@@ -1,42 +1,23 @@
import socket
from collections.abc import Callable
from http.server import BaseHTTPRequestHandler
from threading import Event as TEvent
from threading import Thread, current_thread
from typing import Any, cast
import pglock
from django.db import OperationalError, connections, transaction
from django.utils.timezone import now
from django.db import OperationalError
from django_dramatiq_postgres.middleware import (
CurrentTask as BaseCurrentTask,
)
from django_dramatiq_postgres.middleware import (
HTTPServer,
HTTPServerThread,
)
from django_dramatiq_postgres.middleware import (
MetricsMiddleware as BaseMetricsMiddleware,
)
from django_dramatiq_postgres.middleware import (
_MetricsHandler as BaseMetricsHandler,
)
from dramatiq import Worker
from dramatiq.broker import Broker
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from psycopg.errors import Error
from setproctitle import setthreadtitle
from structlog.stdlib import get_logger
from authentik import authentik_full_version
from authentik.events.models import Event, EventAction
from authentik.lib.config import CONFIG
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.utils.reflection import class_to_path
from authentik.root.monitoring import monitoring_set
from authentik.root.signals import post_startup, pre_startup, startup
from authentik.tasks.models import Task, TaskLog, TaskStatus, WorkerStatus
from authentik.tasks.models import Task, TaskLog, TaskStatus
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant
@@ -193,154 +174,15 @@ class DescriptionMiddleware(Middleware):
return {"description"}
class _healthcheck_handler(BaseHTTPRequestHandler):
def log_request(self, code="-", size="-"):
HEALTHCHECK_LOGGER.info(
self.path,
method=self.command,
status=code,
)
def log_error(self, format, *args):
HEALTHCHECK_LOGGER.warning(format, *args)
def do_HEAD(self):
try:
for db_conn in connections.all():
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
self.send_response(200)
except DB_ERRORS: # pragma: no cover
self.send_response(503)
self.send_header("Content-Type", "text/plain; charset=utf-8")
self.send_header("Content-Length", "0")
self.end_headers()
do_GET = do_HEAD
class WorkerHealthcheckMiddleware(Middleware):
thread: HTTPServerThread | None
def __init__(self):
listen = CONFIG.get("listen.http", ["[::]:9000"])
if isinstance(listen, str):
listen = listen.split(",")
host, _, port = listen[0].rpartition(":")
try:
port = int(port)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
self.host, self.port = host, port
def after_worker_boot(self, broker: Broker, worker: Worker):
self.thread = HTTPServerThread(
target=WorkerHealthcheckMiddleware.run, args=(self.host, self.port)
)
self.thread.start()
def before_worker_shutdown(self, broker: Broker, worker: Worker):
server = self.thread.server
if server:
server.shutdown()
LOGGER.debug("Stopping WorkerHealthcheckMiddleware")
self.thread.join()
@staticmethod
def run(addr: str, port: int):
setthreadtitle("authentik Worker Healthcheck server")
try:
server = HTTPServer((addr, port), _healthcheck_handler)
thread = cast(HTTPServerThread, current_thread())
thread.server = server
server.serve_forever()
except OSError as exc:
get_logger(__name__, type(WorkerHealthcheckMiddleware)).warning(
"Port is already in use, not starting healthcheck server",
exc=exc,
)
class WorkerStatusMiddleware(Middleware):
thread: Thread | None
thread_event: TEvent | None
def after_worker_boot(self, broker: Broker, worker: Worker):
self.thread_event = TEvent()
self.thread = Thread(target=WorkerStatusMiddleware.run, args=(self.thread_event,))
self.thread.start()
def before_worker_shutdown(self, broker: Broker, worker: Worker):
self.thread_event.set()
LOGGER.debug("Stopping WorkerStatusMiddleware")
self.thread.join()
@staticmethod
def run(event: TEvent):
setthreadtitle("authentik Worker status")
with transaction.atomic():
hostname = socket.gethostname()
WorkerStatus.objects.filter(hostname=hostname).delete()
status, _ = WorkerStatus.objects.update_or_create(
hostname=hostname,
version=authentik_full_version(),
)
while not event.is_set():
try:
WorkerStatusMiddleware.keep(event, status)
except DB_ERRORS: # pragma: no cover
event.wait(10)
try:
connections.close_all()
except DB_ERRORS:
pass
@staticmethod
def keep(event: TEvent, status: WorkerStatus):
lock_id = f"goauthentik.io/worker/status/{status.pk}"
with pglock.advisory(lock_id, side_effect=pglock.Raise):
while not event.is_set():
status.refresh_from_db()
old_last_seen = status.last_seen
status.last_seen = now()
if old_last_seen != status.last_seen:
status.save(update_fields=("last_seen",))
event.wait(30)
class _MetricsHandler(BaseMetricsHandler):
def do_GET(self) -> None:
monitoring_set.send_robust(self)
return super().do_GET()
class MetricsMiddleware(BaseMetricsMiddleware):
thread: HTTPServerThread | None
handler_class = _MetricsHandler
@property
def forks(self) -> list[Callable[[], None]]:
def forks(self):
return []
def after_worker_boot(self, broker: Broker, worker: Worker):
listen = CONFIG.get("listen.metrics", ["[::]:9300"])
if isinstance(listen, str):
listen = listen.split(",")
addr, _, port = listen[0].rpartition(":")
def before_worker_boot(self, broker: Broker, worker: Any) -> None:
from prometheus_client import values
from prometheus_client.values import MultiProcessValue
try:
port = int(port)
except ValueError:
LOGGER.error(f"Invalid port entered: {port}")
self.thread = HTTPServerThread(target=MetricsMiddleware.run, args=(addr, port))
self.thread.start()
values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
def before_worker_shutdown(self, broker: Broker, worker: Worker):
server = self.thread.server
if server:
server.shutdown()
LOGGER.debug("Stopping MetricsMiddleware")
self.thread.join()
return super().before_worker_boot(broker, worker)

View File

@@ -1,4 +1,6 @@
from django.utils.timezone import now, timedelta
from datetime import timedelta
from django.utils.timezone import now
from django.utils.translation import gettext_lazy as _
from dramatiq import actor

View File

@@ -10,7 +10,6 @@ from dramatiq.results.middleware import Results
from dramatiq.worker import Worker, _ConsumerThread, _WorkerThread
from authentik.tasks.broker import PostgresBroker
from authentik.tasks.middleware import WorkerHealthcheckMiddleware
TESTING_QUEUE = "testing"
@@ -18,6 +17,7 @@ TESTING_QUEUE = "testing"
class TestWorker(Worker):
def __init__(self, broker: Broker):
super().__init__(broker=broker)
self.worker_id = 1000
self.work_queue = PriorityQueue()
self.consumers = {
TESTING_QUEUE: _ConsumerThread(
@@ -82,8 +82,6 @@ def use_test_broker():
middleware: Middleware = import_string(middleware_class)(
**middleware_kwargs,
)
if isinstance(middleware, WorkerHealthcheckMiddleware):
middleware.port = 9102
if isinstance(middleware, Retries):
middleware.max_retries = 0
if isinstance(middleware, Results):

View File

@@ -1,10 +1,6 @@
#!/usr/bin/env -S bash
set -e -o pipefail
MODE_FILE="${TMPDIR}/authentik-mode"
#!/usr/bin/env bash
if [[ -z "${PROMETHEUS_MULTIPROC_DIR}" ]]; then
export PROMETHEUS_MULTIPROC_DIR="${TMPDIR:-/tmp}/authentik_prometheus_tmp"
fi
set -e -o pipefail
function log {
printf '{"event": "%s", "level": "info", "logger": "bootstrap"}\n' "$@" >&2
@@ -15,10 +11,18 @@ function wait_for_db {
log "Bootstrap completed"
}
function check_if_root {
function run_authentik {
if [[ -x "$(command -v authentik)" ]]; then
echo authentik "$@"
else
echo cargo run -- "$@"
fi
}
function check_if_root_and_run {
if [[ $EUID -ne 0 ]]; then
log "Not running as root, disabling permission fixes"
exec $1
exec $(run_authentik "$@")
return
fi
SOCKET="/var/run/docker.sock"
@@ -26,36 +30,19 @@ function check_if_root {
if [[ -e "$SOCKET" ]]; then
# Get group ID of the docker socket, so we can create a matching group and
# add ourselves to it
DOCKER_GID=$(stat -c '%g' $SOCKET)
DOCKER_GID="$(stat -c "%g" "${SOCKET}")"
# Ensure group for the id exists
getent group $DOCKER_GID || groupadd -f -g $DOCKER_GID docker
usermod -a -G $DOCKER_GID authentik
getent group "${DOCKER_GID}" || groupadd -f -g "${DOCKER_GID}" docker
usermod -a -G "${DOCKER_GID}" authentik
# since the name of the group might not be docker, we need to lookup the group id
GROUP_NAME=$(getent group $DOCKER_GID | sed 's/:/\n/g' | head -1)
GROUP_NAME=$(getent group "${DOCKER_GID}" | sed 's/:/\n/g' | head -1)
GROUP="authentik:${GROUP_NAME}"
fi
# Fix permissions of certs and media
chown -R authentik:authentik /data /certs "${PROMETHEUS_MULTIPROC_DIR}"
chmod ug+rwx /data
chmod ug+rx /certs
exec chpst -u authentik:$GROUP env HOME=/authentik $1
}
function run_authentik {
if [[ -x "$(command -v authentik)" ]]; then
exec authentik $@
else
exec go run -v ./cmd/server/ $@
fi
}
function set_mode {
echo $1 >$MODE_FILE
trap cleanup EXIT
}
function cleanup {
rm -f ${MODE_FILE}
exec chpst -u authentik:"${GROUP}" env HOME=/authentik $(run_authentik "$@")
}
function prepare_debug {
@@ -72,38 +59,31 @@ function prepare_debug {
chown authentik:authentik /unittest.xml
}
if [[ -z "${PROMETHEUS_MULTIPROC_DIR}" ]]; then
export PROMETHEUS_MULTIPROC_DIR="${TMPDIR:-/tmp}/authentik_prometheus_tmp"
fi
mkdir -p "${PROMETHEUS_MULTIPROC_DIR}"
if [[ "$(python -m authentik.lib.config debugger 2>/dev/null)" == "True" ]]; then
prepare_debug
fi
if [[ "$1" == "server" ]]; then
set_mode "server"
run_authentik
elif [[ "$1" == "worker" ]]; then
set_mode "worker"
shift
# If we have bootstrap credentials set, run bootstrap tasks outside of main server
# sync, so that we can sure the first start actually has working bootstrap
# credentials
if [[ -n "${AUTHENTIK_BOOTSTRAP_PASSWORD}" || -n "${AUTHENTIK_BOOTSTRAP_TOKEN}" ]]; then
python -m manage apply_blueprint system/bootstrap.yaml || true
fi
check_if_root "python -m manage worker --pid-file ${TMPDIR}/authentik-worker.pid $@"
elif [[ "$1" == "bash" ]]; then
/bin/bash
elif [[ "$1" == "test-all" ]]; then
prepare_debug
chmod 777 /root
check_if_root "python -m manage test authentik"
elif [[ "$1" == "healthcheck" ]]; then
run_authentik healthcheck $(cat $MODE_FILE)
if [[ "$1" == "bash" ]]; then
exec /usr/bin/env -S bash "$@"
elif [[ "$1" == "dump_config" ]]; then
shift
exec python -m authentik.lib.config $@
shift 1
exec python -m authentik.lib.config "$@"
elif [[ "$1" == "debug" ]]; then
exec sleep infinity
elif [[ "$1" == "test-all" ]]; then
wait_for_db
prepare_debug
chmod 777 /root
check_if_root_and_run manage test authentik
elif [[ "$1" == "allinone" ]] || [[ "$1" == "server" ]] || [[ "$1" == "worker" ]] || [[ "$1" == "proxy" ]] || [[ "$1" == "manage" ]]; then
wait_for_db
check_if_root_and_run "$@"
else
wait_for_db
exec python -m manage "$@"
fi

View File

@@ -1,6 +1,8 @@
"""Gunicorn config"""
import os
import platform
import signal
from hashlib import sha512
from pathlib import Path
from tempfile import gettempdir
@@ -17,7 +19,6 @@ from authentik.lib.utils.reflection import get_env
from authentik.root.install_id import get_install_id_raw
from authentik.root.setup import setup
from lifecycle.migrate import run_migrations
from lifecycle.wait_for_db import wait_for_db
from lifecycle.worker import DjangoUvicornWorker
if TYPE_CHECKING:
@@ -28,16 +29,12 @@ if TYPE_CHECKING:
setup()
wait_for_db()
_tmp = Path(gettempdir())
worker_class = "lifecycle.worker.DjangoUvicornWorker"
worker_tmp_dir = str(_tmp.joinpath("authentik_gunicorn_tmp"))
os.makedirs(worker_tmp_dir, exist_ok=True)
bind = f"unix://{str(_tmp.joinpath('authentik-core.sock'))}"
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
preload_app = True
@@ -45,13 +42,28 @@ preload_app = True
max_requests = CONFIG.get_int("web.max_requests", 1000)
max_requests_jitter = CONFIG.get_int("web.max_requests_jitter", 50)
# Match the value in src/arbiter.rs for graceful shutdown
dirty_graceful_timeout = 30
logconfig_dict = get_logger_config()
default_workers = 2
workers = CONFIG.get_int("web.workers", default_workers)
workers = CONFIG.get_int("web.workers", 2)
threads = CONFIG.get_int("web.threads", 4)
# libpq can try Kerberos/GSS on macOS, which is not fork-safe in our Gunicorn worker model.
# Disable GSS negotiation for local/dev PostgreSQL connections on Darwin.
if platform.system() == "Darwin":
os.environ.setdefault("PGGSSENCMODE", "disable")
# Avoid macOS SystemConfiguration proxy lookups (_scproxy) in forked workers.
# urllib/requests may consult these APIs and can crash in child workers.
os.environ.setdefault("NO_PROXY", "*")
os.environ.setdefault("no_proxy", "*")
def when_ready(server: "Arbiter"): # noqa: UP037
# Notify rust process that we are ready
os.kill(os.getppid(), signal.SIGUSR1)
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): # noqa: UP037
"""Tell prometheus to use worker number instead of process ID for multiprocess"""

148
lifecycle/worker_process.py Normal file
View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
import os
import random
import signal
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from socket import AF_UNIX
from threading import Event, Thread
from typing import Any
from dramatiq import Worker, get_broker
from structlog.stdlib import get_logger
from authentik.lib.config import CONFIG
LOGGER = get_logger()
INITIAL_WORKER_ID = 1000
class HttpHandler(BaseHTTPRequestHandler):
def check_db(self):
from django.db import connections
for db_conn in connections.all():
# Force connection reload
db_conn.connect()
_ = db_conn.cursor()
def do_GET(self):
if self.path == "/-/metrics/":
from authentik.root.monitoring import monitoring_set
monitoring_set.send_robust(self)
self.send_response(200)
self.end_headers()
elif self.path == "/-/health/ready/":
from django.db.utils import OperationalError
try:
self.check_db()
except OperationalError:
self.send_response(503)
self.send_response(200)
self.end_headers()
else:
self.send_response(200)
self.end_headers()
def log_message(self, format: str, *args: Any) -> None:
pass
class UnixSocketServer(HTTPServer):
address_family = AF_UNIX
def main(worker_id: int, socket_path: str | None):
shutdown = Event()
srv = None
def immediate_shutdown(signum, frame):
nonlocal srv
if srv is not None:
srv.shutdown()
if socket_path:
os.remove(socket_path)
sys.exit(0)
def graceful_shutdown(signum, frame):
nonlocal shutdown
shutdown.set()
signal.signal(signal.SIGHUP, immediate_shutdown)
signal.signal(signal.SIGINT, immediate_shutdown)
signal.signal(signal.SIGQUIT, immediate_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)
random.seed()
logger = LOGGER.bind(worker_id=worker_id)
logger.debug("Loading broker...")
broker = get_broker()
broker.emit_after("process_boot")
logger.debug("Starting worker threads...")
queues = None # all queues
worker = Worker(broker, queues=queues, worker_threads=CONFIG.get_int("worker.threads"))
worker.worker_id = worker_id
worker.start()
logger.info("Worker process is ready for action.")
if socket_path:
srv = UnixSocketServer(socket_path, HttpHandler)
Thread(target=srv.serve_forever).start()
# Notify rust process that we are ready
os.kill(os.getppid(), signal.SIGUSR2)
shutdown.wait()
logger.info("Shutting down worker...")
if srv is not None:
srv.shutdown()
if socket_path:
os.remove(socket_path)
# 5 secs if debug, 5 mins otherwise
worker.stop(timeout=5_000 if CONFIG.get_bool("debug") else 600_000)
broker.close()
logger.info("Worker shut down.")
if __name__ == "__main__":
if len(sys.argv) not in [2, 3]:
print("USAGE: worker_process <worker_id> [SOCKET_PATH]")
sys.exit(1)
worker_id = int(sys.argv[1])
socket_path = sys.argv[2] if len(sys.argv) == 3 else None # noqa: PLR2004
from authentik.root.setup import setup
setup()
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "authentik.root.settings")
import django
django.setup()
from django.core.management import execute_from_command_line
if socket_path:
from lifecycle.migrate import run_migrations
run_migrations()
if (
"AUTHENTIK_BOOTSTRAP_PASSWORD" in os.environ
or "AUTHENTIK_BOOTSTRAP_TOKEN" in os.environ
):
try:
execute_from_command_line(["", "apply_blueprint", "system/bootstrap.yaml"])
except Exception as exc: # noqa: BLE001
sys.stderr.write(f"Failed to apply bootstrap blueprint: {exc}")
main(worker_id, socket_path)

View File

@@ -8,12 +8,10 @@ from django.utils.autoreload import DJANGO_AUTORELOAD_ENV
from authentik.root.setup import setup
from lifecycle.migrate import run_migrations
from lifecycle.wait_for_db import wait_for_db
setup()
if __name__ == "__main__":
wait_for_db()
if (
len(sys.argv) > 1
# Explicitly only run migrate for server and worker

291
src/arbiter.rs Normal file
View File

@@ -0,0 +1,291 @@
//! Utilities to manage long running tasks, such as servers and watchers.
//!
//! Also manages signals sent to the main process.
use std::{net, os::unix, sync::Arc, time::Duration};
use axum_server::Handle;
use eyre::{Report, Result};
use tokio::{
signal::unix::{Signal, SignalKind, signal},
sync::{Mutex, broadcast, watch},
task::{JoinSet, join_set::Builder},
};
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::info;
/// All the signal streams we watch for. We don't create those directly in [`watch_signals`]
/// because that would prevent us from handling errors early.
struct SignalStreams {
hup: Signal,
int: Signal,
quit: Signal,
usr1: Signal,
usr2: Signal,
term: Signal,
}
impl SignalStreams {
fn new() -> Result<Self> {
Ok(Self {
hup: signal(SignalKind::hangup())?,
int: signal(SignalKind::interrupt())?,
quit: signal(SignalKind::quit())?,
usr1: signal(SignalKind::user_defined1())?,
usr2: signal(SignalKind::user_defined2())?,
term: signal(SignalKind::terminate())?,
})
}
}
/// Watch for incoming signals and either shutdown the application or dispatch them to receivers.
async fn watch_signals(
streams: SignalStreams,
arbiter: Arbiter,
_signals_rx: broadcast::Receiver<SignalKind>,
) -> Result<()> {
info!("starting signals watcher");
let SignalStreams {
mut hup,
mut int,
mut quit,
mut usr1,
mut usr2,
mut term,
} = streams;
loop {
tokio::select! {
_ = hup.recv() => {
info!("signal HUP received");
arbiter.do_fast_shutdown().await;
},
_ = int.recv() => {
info!("signal INT received");
arbiter.do_fast_shutdown().await;
},
_ = quit.recv() => {
info!("signal QUIT received");
arbiter.do_fast_shutdown().await;
},
_ = usr1.recv() => {
info!("signal URS1 received");
arbiter.signals_tx.send(SignalKind::user_defined1())?;
},
_ = usr2.recv() => {
info!("USR2 received.");
arbiter.signals_tx.send(SignalKind::user_defined2())?;
},
_ = term.recv() => {
info!("signal TERM received");
arbiter.do_graceful_shutdown().await;
},
() = arbiter.shutdown() => {
info!("stopping signals watcher");
return Ok(());
}
};
}
}
/// Manager for long running tasks, such as servers and watchers.
pub(crate) struct Tasks {
pub(crate) tasks: JoinSet<Result<()>>,
arbiter: Arbiter,
}
impl Tasks {
pub(crate) fn new() -> Result<Self> {
let mut tasks = JoinSet::new();
let arbiter = Arbiter::new(&mut tasks)?;
Ok(Self { tasks, arbiter })
}
/// Build a new task. See [`tokio::task::JoinSet::build_task`] for details.
pub(crate) fn build_task(&mut self) -> Builder<'_, Result<()>> {
self.tasks.build_task()
}
/// Get an [`Arbiter`]
pub(crate) fn arbiter(&self) -> Arbiter {
self.arbiter.clone()
}
pub(crate) async fn run(self) -> Vec<Report> {
let Self { mut tasks, arbiter } = self;
let mut errors = Vec::new();
if let Some(result) = tasks.join_next().await {
arbiter.do_graceful_shutdown().await;
match result {
Ok(Ok(())) => {}
Ok(Err(err)) => {
arbiter.do_fast_shutdown().await;
errors.push(err);
}
Err(err) => {
arbiter.do_fast_shutdown().await;
errors.push(Report::new(err));
}
}
while let Some(result) = tasks.join_next().await {
match result {
Ok(Ok(())) => {}
Ok(Err(err)) => errors.push(err),
Err(err) => errors.push(Report::new(err)),
}
}
}
errors
}
}
/// Manage shutdown state and several communication channels.
#[derive(Clone)]
pub(crate) struct Arbiter {
/// Token to shutdown the application immediately.
fast_shutdown: CancellationToken,
/// Token to shutdown the application gracefully.
graceful_shutdown: CancellationToken,
/// Token set when any shutdown is triggered.
shutdown: CancellationToken,
/// Axum handles to manage
net_handles: Arc<Mutex<Vec<Handle<net::SocketAddr>>>>,
unix_handles: Arc<Mutex<Vec<Handle<unix::net::SocketAddr>>>>,
/// Broadcaster of signals sent to the main process.
signals_tx: broadcast::Sender<SignalKind>,
/// Watcher of config change events
config_changed_tx: watch::Sender<()>,
_config_changed_rx: watch::Receiver<()>,
/// Token set when gunicorn is marked ready
gunicorn_ready: CancellationToken,
}
impl Arbiter {
fn new(tasks: &mut JoinSet<Result<()>>) -> Result<Self> {
let (signals_tx, signals_rx) = broadcast::channel(10);
let (config_changed_tx, config_changed_rx) = watch::channel(());
let arbiter = Self {
fast_shutdown: CancellationToken::new(),
graceful_shutdown: CancellationToken::new(),
shutdown: CancellationToken::new(),
// 5 is http, https, metrics and a bit of room
net_handles: Arc::new(Mutex::new(Vec::with_capacity(5))),
// 2 is http and metrics
unix_handles: Arc::new(Mutex::new(Vec::with_capacity(2))),
signals_tx,
config_changed_tx,
_config_changed_rx: config_changed_rx,
gunicorn_ready: CancellationToken::new(),
};
let streams = SignalStreams::new()?;
tasks
.build_task()
.name(&format!("{}::watch_signals", module_path!()))
.spawn(watch_signals(streams, arbiter.clone(), signals_rx))?;
Ok(arbiter)
}
pub(crate) async fn add_net_handle(&self, handle: Handle<net::SocketAddr>) {
self.net_handles.lock().await.push(handle);
}
pub(crate) async fn add_unix_handle(&self, handle: Handle<unix::net::SocketAddr>) {
self.unix_handles.lock().await.push(handle);
}
/// Future that will complete when the application needs to shutdown immediately.
pub(crate) fn fast_shutdown(&self) -> WaitForCancellationFuture<'_> {
self.fast_shutdown.cancelled()
}
/// Future that will complete when the application needs to shutdown gracefully.
pub(crate) fn graceful_shutdown(&self) -> WaitForCancellationFuture<'_> {
self.graceful_shutdown.cancelled()
}
/// Future that will complete when the application needs to shutdown, either immediately or
/// gracefully. It's a helper so users that don't make the difference between immediate and
/// graceful shutdown don't need to handle two scenarios.
pub(crate) fn shutdown(&self) -> WaitForCancellationFuture<'_> {
self.shutdown.cancelled()
}
/// Shutdown the application immediately.
async fn do_fast_shutdown(&self) {
info!("arbiter has been told to shutdown immediately");
self.unix_handles
.lock()
.await
.iter()
.for_each(Handle::shutdown);
self.net_handles
.lock()
.await
.iter()
.for_each(Handle::shutdown);
info!("all webservers have been shutdown, shutting down the other tasks immediately");
self.fast_shutdown.cancel();
self.shutdown.cancel();
}
/// Shutdown the application gracefully.
async fn do_graceful_shutdown(&self) {
info!("arbiter has been told to shutdown gracefully");
// Match the value in lifecycle/gunicorn.conf.py for graceful shutdown
let timeout = Some(Duration::from_secs(30 + 5));
self.unix_handles
.lock()
.await
.iter()
.for_each(|handle| handle.graceful_shutdown(timeout));
self.net_handles
.lock()
.await
.iter()
.for_each(|handle| handle.graceful_shutdown(timeout));
info!("all webservers have been shutdown, shutting down the other tasks gracefully");
self.graceful_shutdown.cancel();
self.shutdown.cancel();
}
/// Create a new [`broadcast::Receiver`] to listen for signals sent to the main process. This
/// may not include all signals we catch, since some of those will shutdown the application.
pub(crate) fn signals_subscribe(&self) -> broadcast::Receiver<SignalKind> {
self.signals_tx.subscribe()
}
/// Send a value on the config changes watch channel
pub(crate) fn config_changed_send(&self, value: ()) -> Result<()> {
self.config_changed_tx.send(value)?;
Ok(())
}
/// Create a new [`watch::Receiver`] to listen for detected configuration changes.
pub(crate) fn config_changed_subscribe(&self) -> watch::Receiver<()> {
self.config_changed_tx.subscribe()
}
/// Future that will complete when the application needs to shutdown gracefully.
pub(crate) fn gunicorn_ready(&self) -> WaitForCancellationFuture<'_> {
self.gunicorn_ready.cancelled()
}
/// Mark gunicorn as ready
pub(crate) fn mark_gunicorn_ready(&self) {
self.gunicorn_ready.cancel();
}
}

2
src/axum/accept/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub(crate) mod proxy_protocol;
pub(crate) mod tls;

View File

@@ -0,0 +1,86 @@
use std::{io, time::Duration};
use axum::{Extension, middleware::AddExtension};
use axum_server::accept::{Accept, DefaultAcceptor};
use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tower::Layer as _;
use tracing::instrument;
use crate::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};
#[derive(Clone, Debug)]
pub(crate) struct ProxyProtocolState {
pub(crate) header: Option<Header<'static>>,
}
#[derive(Clone)]
pub(crate) struct ProxyProtocolAcceptor<A = DefaultAcceptor> {
inner: A,
parsing_timeout: Duration,
}
impl ProxyProtocolAcceptor {
pub(crate) fn new() -> Self {
let inner = DefaultAcceptor::new();
#[cfg(not(test))]
let parsing_timeout = Duration::from_secs(10);
// Don't force tests to wait too long
#[cfg(test)]
let parsing_timeout = Duration::from_secs(1);
Self {
inner,
parsing_timeout,
}
}
}
impl Default for ProxyProtocolAcceptor {
fn default() -> Self {
Self::new()
}
}
impl<A> ProxyProtocolAcceptor<A> {
pub(crate) fn acceptor<Acceptor>(self, acceptor: Acceptor) -> ProxyProtocolAcceptor<Acceptor> {
ProxyProtocolAcceptor {
inner: acceptor,
parsing_timeout: self.parsing_timeout,
}
}
}
impl<A, I, S> Accept<I, S> for ProxyProtocolAcceptor<A>
where
A: Accept<I, S> + Clone + Send + 'static,
A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
A::Service: Send,
A::Future: Send,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
S: Send + 'static,
{
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
type Service = AddExtension<A::Service, ProxyProtocolState>;
type Stream = ProxyProtocolStream<A::Stream>;
#[instrument(skip_all)]
fn accept(&self, stream: I, service: S) -> Self::Future {
let acceptor = self.inner.clone();
Box::pin(async move {
let (stream, service) = acceptor.accept(stream, service).await?;
let stream = ProxyProtocolStream::new(stream).await?;
let proxy_protocol_state = ProxyProtocolState {
header: stream.header().cloned(),
};
let service = Extension(proxy_protocol_state).layer(service);
Ok((stream, service))
})
}
}

54
src/axum/accept/tls.rs Normal file
View File

@@ -0,0 +1,54 @@
use axum::{Extension, middleware::AddExtension};
use axum_server::{accept::Accept, tls_rustls::RustlsAcceptor};
use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{rustls::pki_types::CertificateDer, server::TlsStream};
use tower::Layer as _;
use tracing::instrument;
#[derive(Clone, Debug)]
pub(crate) struct TlsState {
pub(crate) peer_certificates: Option<Vec<CertificateDer<'static>>>,
}
#[derive(Clone)]
pub(crate) struct TlsAcceptor<A> {
inner: RustlsAcceptor<A>,
}
impl<A> TlsAcceptor<A> {
pub(crate) fn new(inner: RustlsAcceptor<A>) -> Self {
Self { inner }
}
}
impl<A, I, S> Accept<I, S> for TlsAcceptor<A>
where
A: Accept<I, S> + Clone + Send + 'static,
A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
A::Service: Send,
A::Future: Send,
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
S: Send + 'static,
{
type Future = BoxFuture<'static, std::io::Result<(Self::Stream, Self::Service)>>;
type Service = AddExtension<A::Service, TlsState>;
type Stream = TlsStream<A::Stream>;
#[instrument(skip_all)]
fn accept(&self, stream: I, service: S) -> Self::Future {
let acceptor = self.inner.clone();
Box::pin(async move {
let (stream, service) = acceptor.accept(stream, service).await?;
let server_conn = stream.get_ref().1;
let tls_state = TlsState {
peer_certificates: server_conn.peer_certificates().map(|c| c.to_owned()),
};
let service = Extension(tls_state).layer(service);
Ok((stream, service))
})
}
}

26
src/axum/error.rs Normal file
View File

@@ -0,0 +1,26 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use eyre::Report;
use tracing::warn;
#[derive(Debug)]
pub(crate) struct AppError(pub(crate) Report);
impl<E> From<E> for AppError
where E: Into<Report>
{
fn from(err: E) -> Self {
Self(err.into())
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
warn!("error occurred: {:?}", self.0);
(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
}
}
pub(crate) type Result<T, E = AppError> = core::result::Result<T, E>;

View File

@@ -0,0 +1,228 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use axum::{
Extension, RequestPartsExt as _,
extract::{ConnectInfo, FromRequestParts, Request},
http::request::Parts,
middleware::Next,
response::Response,
};
use tracing::{Span, instrument};
use crate::axum::{
accept::proxy_protocol::ProxyProtocolState, extract::trusted_proxy::TrustedProxy,
};
#[derive(Clone, Copy, Debug)]
pub(crate) struct ClientIp(pub IpAddr);
impl<S> FromRequestParts<S> for ClientIp
where S: Send + Sync
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(client_ip)| client_ip)
}
}
#[instrument(skip_all)]
async fn extract_client_ip(parts: &mut Parts) -> IpAddr {
let is_trusted = parts
.extract::<TrustedProxy>()
.await
.unwrap_or(TrustedProxy(false))
.0;
if is_trusted {
if let Ok(ip) = client_ip::rightmost_x_forwarded_for(&parts.headers) {
return ip;
}
if let Ok(ip) = client_ip::x_real_ip(&parts.headers) {
return ip;
}
if let Ok(ip) = client_ip::rightmost_forwarded(&parts.headers) {
return ip;
}
if let Ok(Extension(proxy_protocol_state)) =
parts.extract::<Extension<ProxyProtocolState>>().await
&& let Some(header) = &proxy_protocol_state.header
&& let Some(addr) = header.proxied_address()
{
return addr.source.ip();
}
}
if let Ok(ConnectInfo(addr)) = parts.extract::<ConnectInfo<SocketAddr>>().await {
addr.ip()
} else {
// No connect info means we received a request via a Unix socket, hence localhost
// as default
Ipv6Addr::LOCALHOST.into()
}
}
pub(crate) async fn client_ip_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
Span::current().record("remote", client_ip.to_string());
parts.extensions.insert::<ClientIp>(ClientIp(client_ip));
let request = Request::from_parts(parts, body);
next.run(request).await
}
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
use axum::{body::Body, http::Request};
use super::*;
#[tokio::test]
async fn x_forwarded_for_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-for", "192.0.2.51, 192.0.2.42")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
}
#[tokio::test]
async fn x_real_ip_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-real-ip", "192.0.2.42")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
}
#[tokio::test]
async fn forwarded_header_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "for=192.0.2.42")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
}
#[tokio::test]
async fn from_connect_info() {
let connect_addr: SocketAddr = "192.0.2.42:34932"
.parse()
.expect("Failed to parse socket address");
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.extension(ConnectInfo(connect_addr))
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 42),);
}
#[tokio::test]
async fn headers_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-for", "192.0.2.42")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
}
#[tokio::test]
async fn priority_order() {
// Test that X-Forwarded-For takes priority over other headers when trusted
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-for", "192.0.2.1")
.header("x-real-ip", "192.0.2.2")
.header("forwarded", "for=192.0.2.3")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 1),);
}
#[tokio::test]
async fn no_ip_found() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv6Addr::LOCALHOST);
}
#[tokio::test]
async fn ipv6() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-for", "2001:db8::42")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0x42),);
}
#[tokio::test]
async fn multiple_x_forwarded_for() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-for", "192.0.2.1, 192.0.2.2, 192.0.2.3")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let client_ip = extract_client_ip(&mut parts).await;
assert_eq!(client_ip, Ipv4Addr::new(192, 0, 2, 3),);
}
}

262
src/axum/extract/host.rs Normal file
View File

@@ -0,0 +1,262 @@
use axum::{
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{
header::{FORWARDED, HOST},
request::Parts,
status::StatusCode,
},
middleware::Next,
response::{IntoResponse as _, Response},
};
use forwarded_header_value::ForwardedHeaderValue;
use tracing::{Span, instrument};
use crate::axum::extract::trusted_proxy::TrustedProxy;
const X_FORWARDED_HOST: &str = "X-Forwarded-Host";
#[derive(Clone, Debug)]
pub(crate) struct Host(pub String);
impl<S> FromRequestParts<S> for Host
where S: Send + Sync
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(host)| host)
}
}
#[instrument(skip_all)]
async fn extract_host(parts: &mut Parts) -> Result<String, (StatusCode, &'static str)> {
let is_trusted = parts
.extract::<TrustedProxy>()
.await
.unwrap_or(TrustedProxy(false))
.0;
if is_trusted {
if let Some(host) = parts
.headers
.get(X_FORWARDED_HOST)
.and_then(|host| host.to_str().ok())
{
return Ok(host.to_owned());
}
if let Some(forwarded) = parts.headers.get(FORWARDED)
&& let Ok(forwarded) = forwarded.to_str()
&& let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
{
for stanza in forwarded.iter() {
if let Some(forwarded_host) = &stanza.forwarded_host {
return Ok(forwarded_host.to_owned());
}
}
}
}
if let Some(host) = parts.headers.get(HOST).and_then(|host| host.to_str().ok()) {
return Ok(host.to_owned());
}
if let Some(host) = parts.uri.host() {
Ok(host.to_owned())
} else {
Err((StatusCode::BAD_REQUEST, "missing host header"))
}
}
pub(crate) async fn host_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let host = match extract_host(&mut parts).await {
Ok(host) => host,
Err(err) => return err.into_response(),
};
Span::current().record("host", host.clone());
parts.extensions.insert::<Host>(Host(host));
let request = Request::from_parts(parts, body);
next.run(request).await
}
#[cfg(test)]
mod tests {
use axum::{body::Body, http::Request};
use super::*;
#[tokio::test]
async fn host_header() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("host", "example.com:8080")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com:8080",
);
}
#[tokio::test]
async fn from_uri() {
let (mut parts, _) = Request::builder()
.uri("http://example.com:8080/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn x_forwarded_host_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "forwarded.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"forwarded.example.com",
);
}
#[tokio::test]
async fn forwarded_header_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "host=forwarded.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"forwarded.example.com",
);
}
#[tokio::test]
async fn forwarded_host_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "malicious.example.com")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn forwarded_header_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "host=malicious.example.com")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"example.com",
);
}
#[tokio::test]
async fn priority_order() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-host", "x-forwarded.example.com")
.header("forwarded", "host=forwarded.example.com")
.header("host", "host-header.example.com")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"x-forwarded.example.com",
);
}
#[tokio::test]
async fn no_host_found() {
let (mut parts, _) = Request::builder()
.uri("/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_err());
assert_eq!(result.expect_err("Host extract should fail").0, 400);
}
#[tokio::test]
async fn multiple_forwarded_stanzas() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header(
"forwarded",
"host=first.example.com, host=second.example.com",
)
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let result = extract_host(&mut parts).await;
assert!(result.is_ok());
assert_eq!(
result.expect("Host extraction should succeed"),
"first.example.com",
);
}
}

4
src/axum/extract/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub(crate) mod client_ip;
pub(crate) mod host;
pub(crate) mod scheme;
pub(crate) mod trusted_proxy;

241
src/axum/extract/scheme.rs Normal file
View File

@@ -0,0 +1,241 @@
use axum::{
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{self, header::FORWARDED, request::Parts},
middleware::Next,
response::Response,
};
use forwarded_header_value::{ForwardedHeaderValue, Protocol};
use tracing::{Span, instrument};
use crate::axum::{
accept::{proxy_protocol::ProxyProtocolState, tls::TlsState},
extract::trusted_proxy::TrustedProxy,
};
const X_FORWARDED_PROTO: &str = "X-Forwarded-Proto";
const X_FORWARDED_SCHEME: &str = "X-Forwarded-Scheme";
#[derive(Clone, Debug)]
pub(crate) struct Scheme(pub http::uri::Scheme);
impl<S> FromRequestParts<S> for Scheme
where S: Send + Sync
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(scheme)| scheme)
}
}
#[instrument(skip_all)]
async fn extract_scheme(parts: &mut Parts) -> http::uri::Scheme {
let is_trusted = parts
.extract::<TrustedProxy>()
.await
.unwrap_or(TrustedProxy(false))
.0;
if is_trusted {
if let Some(proto) = parts.headers.get(X_FORWARDED_PROTO)
&& let Ok(proto) = proto.to_str()
&& let Ok(scheme) = proto.to_lowercase().as_str().try_into()
{
return scheme;
}
if let Some(proto) = parts.headers.get(X_FORWARDED_SCHEME)
&& let Ok(proto) = proto.to_str()
&& let Ok(scheme) = proto.to_lowercase().as_str().try_into()
{
return scheme;
}
if let Some(forwarded) = parts.headers.get(FORWARDED)
&& let Ok(forwarded) = forwarded.to_str()
&& let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded)
{
for stanza in forwarded.iter() {
if let Some(forwarded_proto) = &stanza.forwarded_proto {
let scheme = match forwarded_proto {
Protocol::Http => http::uri::Scheme::HTTP,
Protocol::Https => http::uri::Scheme::HTTPS,
};
return scheme;
}
}
}
if let Ok(Extension(proxy_protocol_state)) =
parts.extract::<Extension<ProxyProtocolState>>().await
&& let Some(header) = &proxy_protocol_state.header
&& let Some(_) = header.ssl()
{
return http::uri::Scheme::HTTPS;
}
}
if parts.extract::<Extension<TlsState>>().await.is_ok() {
http::uri::Scheme::HTTPS
} else {
http::uri::Scheme::HTTP
}
}
pub(crate) async fn scheme_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let scheme = extract_scheme(&mut parts).await;
Span::current().record("scheme", scheme.to_string());
parts.extensions.insert::<Scheme>(Scheme(scheme));
let request = Request::from_parts(parts, body);
next.run(request).await
}
#[cfg(test)]
mod tests {
use axum::{body::Body, http::Request};
use super::*;
#[tokio::test]
async fn x_forwarded_proto_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn x_forwarded_scheme_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-scheme", "https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn forwarded_header_trusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn x_forwarded_proto_untrusted() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "https")
.extension(TrustedProxy(false))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn scheme_from_tls_state() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.extension(TlsState {
peer_certificates: None,
})
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
#[tokio::test]
async fn scheme_defaults_to_http() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn priority_order() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "http")
.header("x-forwarded-scheme", "https")
.header("forwarded", "proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn multiple_forwarded_stanzas() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("forwarded", "proto=http, proto=https")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTP,);
}
#[tokio::test]
async fn test_scheme_case_insensitive() {
let (mut parts, _) = Request::builder()
.uri("http://example.com/path")
.header("x-forwarded-proto", "HTTPS")
.extension(TrustedProxy(true))
.body(Body::empty())
.expect("Failed to create request")
.into_parts();
let scheme = extract_scheme(&mut parts).await;
assert_eq!(scheme, http::uri::Scheme::HTTPS,);
}
}

View File

@@ -0,0 +1,59 @@
use std::net::SocketAddr;
use axum::{
Extension, RequestPartsExt as _,
extract::{ConnectInfo, FromRequestParts, Request},
http::request::Parts,
middleware::Next,
response::Response,
};
use tracing::{instrument, trace};
use crate::config;
#[derive(Clone, Copy, Debug)]
pub(crate) struct TrustedProxy(pub bool);
impl<S> FromRequestParts<S> for TrustedProxy
where S: Send + Sync
{
type Rejection = <Extension<Self> as FromRequestParts<S>>::Rejection;
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
Extension::<Self>::from_request_parts(parts, state)
.await
.map(|Extension(trusted_proxy)| trusted_proxy)
}
}
#[instrument(skip_all)]
async fn extract_trusted_proxy(parts: &mut Parts) -> bool {
if let Ok(ConnectInfo(addr)) = parts.extract::<ConnectInfo<SocketAddr>>().await {
let trusted_proxy_cidrs = &config::get().listen.trusted_proxy_cidrs;
for trusted_net in trusted_proxy_cidrs {
if trusted_net.contains(&addr.ip()) {
trace!(
?addr,
?trusted_net,
"connection is now considered coming from a trusted proxy"
);
return true;
}
}
}
false
}
pub(crate) async fn trusted_proxy_middleware(request: Request, next: Next) -> Response {
let (mut parts, body) = request.into_parts();
let trusted_proxy = extract_trusted_proxy(&mut parts).await;
parts
.extensions
.insert::<TrustedProxy>(TrustedProxy(trusted_proxy));
let request = Request::from_parts(parts, body);
next.run(request).await
}

6
src/axum/mod.rs Normal file
View File

@@ -0,0 +1,6 @@
pub(crate) mod accept;
pub(crate) mod error;
pub(crate) mod extract;
pub(crate) mod router;
pub(crate) mod server;
pub(crate) mod trace;

39
src/axum/router.rs Normal file
View File

@@ -0,0 +1,39 @@
use axum::{Router, http::status::StatusCode, middleware::from_fn};
use tower::ServiceBuilder;
use tower_http::timeout::TimeoutLayer;
use crate::{
axum::{
extract::{
client_ip::client_ip_middleware, host::host_middleware, scheme::scheme_middleware,
trusted_proxy::trusted_proxy_middleware,
},
trace::{span_middleware, tracing_middleware},
},
config,
};
#[inline]
pub(crate) fn wrap_router(router: Router, with_trace: bool) -> Router {
let config = config::get();
let timeout = durstr::parse(&config.web.timeout_http_read_header)
.expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_read).expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_write).expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_idle).expect("Invalid duration in http timeout");
let service_builder = ServiceBuilder::new()
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
timeout,
))
.layer(from_fn(span_middleware))
.layer(from_fn(trusted_proxy_middleware))
.layer(from_fn(client_ip_middleware))
.layer(from_fn(scheme_middleware))
.layer(from_fn(host_middleware));
if with_trace {
router.layer(service_builder.layer(from_fn(tracing_middleware)))
} else {
router.layer(service_builder)
}
}

119
src/axum/server.rs Normal file
View File

@@ -0,0 +1,119 @@
use std::{net, os::unix};
use axum::Router;
use axum_server::{
Handle,
accept::DefaultAcceptor,
tls_rustls::{RustlsAcceptor, RustlsConfig},
};
use eyre::Result;
use tracing::info;
use crate::{
arbiter::{Arbiter, Tasks},
axum::accept::{proxy_protocol::ProxyProtocolAcceptor, tls::TlsAcceptor},
};
async fn run_plain(
arbiter: Arbiter,
name: &str,
router: Router,
addr: net::SocketAddr,
) -> Result<()> {
info!(addr = addr.to_string(), "starting {name} server");
let handle = Handle::new();
arbiter.add_net_handle(handle.clone()).await;
axum_server::Server::bind(addr)
.handle(handle)
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
.await?;
Ok(())
}
pub(crate) fn start_plain(
tasks: &mut Tasks,
name: &'static str,
router: Router,
addr: net::SocketAddr,
) -> Result<()> {
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::run_plain({name}, {addr})", module_path!()))
.spawn(run_plain(arbiter, name, router, addr))?;
Ok(())
}
pub(crate) async fn run_unix(
arbiter: Arbiter,
name: &str,
router: Router,
addr: unix::net::SocketAddr,
) -> Result<()> {
info!(addr = ?addr, "starting {name} server");
let handle = Handle::new();
arbiter.add_unix_handle(handle.clone()).await;
axum_server::Server::bind(addr)
.handle(handle)
.serve(router.into_make_service())
.await?;
Ok(())
}
pub(crate) fn start_unix(
tasks: &mut Tasks,
name: &'static str,
router: Router,
addr: unix::net::SocketAddr,
) -> Result<()> {
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::run_unix({name}, {addr:?})", module_path!()))
.spawn(run_unix(arbiter, name, router, addr))?;
Ok(())
}
async fn run_tls(
arbiter: Arbiter,
name: &str,
router: Router,
addr: net::SocketAddr,
config: RustlsConfig,
) -> Result<()> {
info!(addr = addr.to_string(), "starting {name} server");
let handle = Handle::new();
arbiter.add_net_handle(handle.clone()).await;
axum_server::Server::bind(addr)
.acceptor(TlsAcceptor::new(RustlsAcceptor::new(config).acceptor(
ProxyProtocolAcceptor::new().acceptor(DefaultAcceptor::new()),
)))
.handle(handle)
.serve(router.into_make_service_with_connect_info::<net::SocketAddr>())
.await?;
Ok(())
}
pub(crate) fn start_tls(
tasks: &mut Tasks,
name: &'static str,
router: Router,
addr: net::SocketAddr,
config: RustlsConfig,
) -> Result<()> {
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::run_tls({name}, {addr})", module_path!()))
.spawn(run_tls(arbiter, name, router, addr, config))?;
Ok(())
}

48
src/axum/trace.rs Normal file
View File

@@ -0,0 +1,48 @@
use std::collections::HashMap;
use axum::{extract::Request, middleware::Next, response::Response};
use tokio::time::Instant;
use tracing::{Instrument as _, field, info, info_span, trace};
use crate::config;
pub(crate) async fn span_middleware(request: Request, next: Next) -> Response {
let config = config::get();
let http_headers = request
.headers()
.iter()
.filter(|(name, _)| {
for header in &config.log.http_headers {
if header.eq_ignore_ascii_case(name.as_str()) {
return true;
}
}
false
})
.map(|(name, value)| (name.to_string().to_lowercase().replace('-', "_"), value))
.collect::<HashMap<_, _>>();
let span = info_span!(
"request",
path = %request.uri(),
method = %request.method(),
remote = field::Empty,
scheme = field::Empty,
host = field::Empty,
http_headers = ?http_headers,
);
next.run(request).instrument(span).await
}
pub(crate) async fn tracing_middleware(request: Request, next: Next) -> Response {
let event = request.uri().clone();
trace!("request start");
let start = Instant::now();
let response = next.run(request).await;
let runtime = start.elapsed();
let status = response.status().as_u16();
info!(status = status, runtime = runtime.as_millis(), "{event}");
response
}

1
src/brands/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub(crate) mod tls;

134
src/brands/tls.rs Normal file
View File

@@ -0,0 +1,134 @@
use std::{
collections::{HashMap, hash_map::Entry},
sync::Arc,
};
use eyre::{Report, Result};
use rustls::{
RootCertStore,
crypto::CryptoProvider,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
server::ClientHello,
sign::CertifiedKey,
};
use crate::db;
#[derive(Debug)]
struct Brand {
domain: String,
default: bool,
web_certificate: Arc<CertifiedKey>,
}
#[derive(Debug)]
pub(crate) struct CertResolver {
brands: Vec<Brand>,
}
impl CertResolver {
pub(crate) fn resolve(&self, client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
let server_name = client_hello.server_name()?;
let mut best = None;
for brand in &self.brands {
if best.is_none() && brand.default {
best = Some(Arc::clone(&brand.web_certificate));
}
if server_name == brand.domain || server_name.ends_with(&format!(".{}", brand.domain)) {
best = Some(Arc::clone(&brand.web_certificate));
}
}
best
}
}
pub(crate) async fn make_cert_managers() -> Result<(CertResolver, RootCertStore)> {
#[derive(sqlx::FromRow)]
struct BrandRow {
brand_uuid: uuid::Uuid,
domain: String,
default: bool,
web_cert_data: Option<String>,
web_cert_key: Option<String>,
client_cert_data: Option<String>,
}
let rows = sqlx::query_as::<_, BrandRow>(
"
SELECT
b.brand_uuid,
b.domain,
b.default,
wc.certificate_data AS web_cert_data,
wc.key_data AS web_cert_key,
cc.certificate_data AS client_cert_data
FROM authentik_brands_brand b
LEFT JOIN authentik_crypto_certificatekeypair wc
ON wc.kp_uuid = b.web_certificate_id
LEFT JOIN authentik_brands_brand_client_certificates bcc
ON bcc.brand_id = b.brand_uuid
LEFT JOIN authentik_crypto_certificatekeypair cc
ON cc.kp_uuid = bcc.certificatekeypair_id
",
)
.fetch_all(db::get())
.await?;
let (brands, roots) = tokio::task::spawn_blocking(|| {
let mut brands = HashMap::new();
let mut roots = RootCertStore::empty();
for row in rows {
let BrandRow {
brand_uuid,
domain,
default,
web_cert_data,
web_cert_key,
client_cert_data,
} = row;
if let (Some(certificate_data), Some(key_data)) = (web_cert_data, web_cert_key)
&& let Entry::Vacant(e) = brands.entry(brand_uuid)
{
let brand = Brand {
domain,
default,
web_certificate: {
let cert_chain =
CertificateDer::pem_reader_iter(certificate_data.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
let key_der = PrivateKeyDer::from_pem_reader(key_data.as_bytes())?;
let provider =
CryptoProvider::get_default().expect("no rustls provider installed");
Arc::new(CertifiedKey::new(
cert_chain,
provider.key_provider.load_private_key(key_der)?,
))
},
};
e.insert(brand);
}
if let Some(certificate_data) = client_cert_data {
let cert_chain = CertificateDer::pem_reader_iter(certificate_data.as_bytes())
.collect::<Result<Vec<_>, _>>()?;
for cert in cert_chain {
roots.add(cert)?;
}
}
}
Ok::<_, Report>((brands, roots))
})
.await??;
Ok((
CertResolver {
brands: brands.into_values().collect(),
},
roots,
))
}

244
src/config/mod.rs Normal file
View File

@@ -0,0 +1,244 @@
use std::{
env,
fs::{self, read_to_string},
path::PathBuf,
sync::{Arc, OnceLock},
};
use arc_swap::{ArcSwap, Guard};
use eyre::Result;
use notify::{RecommendedWatcher, Watcher as _};
use serde_json::{Map, Value};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
pub(crate) mod schema;
pub(crate) use schema::Config;
use url::Url;
use crate::arbiter::{Arbiter, Tasks};
static DEFAULT_CONFIG: &str = include_str!("../../authentik/lib/default.yml");
static CONFIG_MANAGER: OnceLock<ConfigManager> = OnceLock::new();
fn config_paths() -> Vec<PathBuf> {
let mut config_paths = vec![
PathBuf::from("/etc/authentik/config.yml"),
PathBuf::from(""),
];
if let Ok(workspace) = env::var("WORKSPACE_DIR") {
let _ = env::set_current_dir(workspace);
}
if let Ok(paths) = glob::glob("/etc/authentik/config.d/*.yml") {
config_paths.extend(paths.filter_map(Result::ok));
}
let environment = env::var("AUTHENTIK_ENV").unwrap_or_else(|_| "local".to_owned());
let mut computed_paths = Vec::new();
for path in config_paths {
if let Ok(metadata) = fs::metadata(&path) {
if !metadata.is_dir() {
computed_paths.push(path);
}
} else {
let env_paths = vec![
path.join(format!("{environment}.yml")),
path.join(format!("{environment}.env.yml")),
];
for env_path in env_paths {
if let Ok(metadata) = fs::metadata(&env_path)
&& !metadata.is_dir()
{
computed_paths.push(env_path);
}
}
}
}
computed_paths
}
impl Config {
fn load_raw(config_paths: &[PathBuf]) -> Result<Value> {
let mut builder = config::Config::builder().add_source(config::File::from_str(
DEFAULT_CONFIG,
config::FileFormat::Yaml,
));
for path in config_paths {
builder = builder
.add_source(config::File::from(path.as_path()).format(config::FileFormat::Yaml));
}
builder = builder.add_source(config::Environment::with_prefix("AUTHENTIK"));
let config = builder.build()?;
let raw = config.try_deserialize::<Value>()?;
Ok(raw)
}
fn expand_value(value: &str) -> (String, Option<PathBuf>) {
let value = value.trim();
if let Ok(uri) = Url::parse(value) {
let fallback = uri.query().unwrap_or("").to_owned();
match uri.scheme() {
"file" => {
let path = uri.path();
match read_to_string(path).map(|s| s.trim().to_owned()) {
Ok(value) => return (value, Some(PathBuf::from(path))),
Err(err) => {
error!("failed to read config value from {path}: {err}");
return (fallback, Some(PathBuf::from(path)));
}
}
}
"env" => {
if let Some(var) = uri.host_str() {
if let Ok(value) = env::var(var) {
return (value, None);
}
return (fallback, None);
}
}
_ => {}
}
}
(value.to_owned(), None)
}
fn expand(mut raw: Value) -> (Value, Vec<PathBuf>) {
let mut file_paths = Vec::new();
let value = match &mut raw {
Value::String(s) => {
let (v, path) = Self::expand_value(s);
if let Some(path) = path {
file_paths.push(path);
}
Value::String(v)
}
Value::Array(arr) => {
let mut res = Vec::with_capacity(arr.len());
for v in arr {
let (expanded, paths) = Self::expand(v.clone());
file_paths.extend(paths);
res.push(expanded);
}
Value::Array(res)
}
Value::Object(map) => {
let mut res = Map::with_capacity(map.len());
for (k, v) in map {
let (expanded, paths) = Self::expand(v.clone());
file_paths.extend(paths);
res.insert(k.clone(), expanded);
}
Value::Object(res)
}
_ => raw,
};
(value, file_paths)
}
fn load(config_paths: &[PathBuf]) -> Result<(Self, Vec<PathBuf>)> {
let raw = Self::load_raw(config_paths)?;
let (expanded, file_paths) = Self::expand(raw);
let config: Self = serde_json::from_value(expanded)?;
Ok((config, file_paths))
}
}
pub(crate) struct ConfigManager {
config: ArcSwap<Config>,
config_paths: Vec<PathBuf>,
watch_paths: Vec<PathBuf>,
}
impl ConfigManager {
pub(crate) fn init() -> Result<()> {
info!("loading config");
let config_paths = config_paths();
let mut watch_paths = config_paths.clone();
let (config, other_paths) = Config::load(&config_paths)?;
watch_paths.extend(other_paths);
let manager = Self {
config: ArcSwap::from_pointee(config),
config_paths,
watch_paths,
};
CONFIG_MANAGER.get_or_init(|| manager);
info!("config loaded");
Ok(())
}
pub(crate) fn run(tasks: &mut Tasks) -> Result<()> {
info!("starting config file watcher");
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::watch_config", module_path!()))
.spawn(watch_config(arbiter))?;
Ok(())
}
}
async fn watch_config(arbiter: Arbiter) -> Result<()> {
let (tx, mut rx) = mpsc::channel(100);
let mut watcher = RecommendedWatcher::new(
move |res: notify::Result<notify::Event>| {
if let Ok(event) = res
&& let notify::EventKind::Modify(_) = &event.kind
{
let _ = tx.blocking_send(());
}
},
notify::Config::default(),
)?;
let watch_paths = &CONFIG_MANAGER
.get()
.expect("failed to get config, has it been initialized?")
.watch_paths;
for path in watch_paths {
watcher.watch(path.as_ref(), notify::RecursiveMode::NonRecursive)?;
}
info!("config file watcher started on paths: {:?}", watch_paths);
loop {
tokio::select! {
res = rx.recv() => {
info!("a configuration file changed, reloading config");
if res.is_none() {
break;
}
let manager = CONFIG_MANAGER.get().expect("failed to get config, has it been initialized?");
match tokio::task::spawn_blocking(|| Config::load(&manager.config_paths)).await? {
Ok((new_config, _)) => {
info!("configuration reloaded");
manager.config.store(Arc::new(new_config));
if let Err(err) = arbiter.config_changed_send(()) {
warn!("failed to notify of config change, aborting: {err:?}");
break;
}
}
Err(err) => {
warn!("failed to reload config, continuing with previous config: {err:?}");
}
}
},
() = arbiter.shutdown() => break,
}
}
info!("stopping config file watcher");
Ok(())
}
pub(crate) fn get() -> Guard<Arc<Config>> {
let manager = CONFIG_MANAGER
.get()
.expect("failed to get config, has it been initialized?");
manager.config.load()
}

146
src/config/schema.rs Normal file
View File

@@ -0,0 +1,146 @@
use std::{collections::HashMap, net::SocketAddr, num::NonZeroUsize, path::PathBuf};
use ipnet::IpNet;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct Config {
pub(crate) postgresql: PostgreSQLConfig,
pub(crate) listen: ListenConfig,
pub(crate) http_timeout: u32,
pub(crate) debug: bool,
#[serde(default)]
pub(crate) secret_key: String,
pub(crate) log_level: String,
pub(crate) log: LogConfig,
pub(crate) error_reporting: ErrorReportingConfig,
pub(crate) outposts: OutpostsConfig,
pub(crate) cookie_domain: Option<String>,
pub(crate) compliance: ComplianceConfig,
pub(crate) blueprints_dir: PathBuf,
pub(crate) cert_discovery_dir: PathBuf,
pub(crate) web: WebConfig,
pub(crate) worker: WorkerConfig,
pub(crate) storage: StorageConfig,
// Outpost specific config
// These are only relevant for outposts, and cannot be set via YAML
// They are loaded via this config loader to support file:// schemas
pub(crate) authentik_host: Option<String>,
pub(crate) authentik_host_browser: Option<String>,
pub(crate) authentik_token: Option<String>,
pub(crate) authentik_insecure: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct PostgreSQLConfig {
pub(crate) host: String,
pub(crate) port: u16,
pub(crate) user: String,
pub(crate) password: String,
pub(crate) name: String,
pub(crate) sslmode: String,
pub(crate) sslrootcert: Option<String>,
pub(crate) sslcert: Option<String>,
pub(crate) sslkey: Option<String>,
pub(crate) conn_max_age: Option<u64>,
pub(crate) conn_health_checks: bool,
pub(crate) default_schema: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ListenConfig {
pub(crate) http: Vec<SocketAddr>,
pub(crate) https: Vec<SocketAddr>,
pub(crate) ldap: Vec<SocketAddr>,
pub(crate) ldaps: Vec<SocketAddr>,
pub(crate) radius: Vec<SocketAddr>,
pub(crate) metrics: Vec<SocketAddr>,
pub(crate) debug: SocketAddr,
pub(crate) trusted_proxy_cidrs: Vec<IpNet>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct LogConfig {
pub(crate) http_headers: Vec<String>,
pub(crate) rust_log: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ErrorReportingConfig {
pub(crate) enabled: bool,
pub(crate) sentry_dsn: Option<String>,
pub(crate) environment: String,
pub(crate) send_pii: bool,
pub(crate) sample_rate: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct OutpostsConfig {
pub(crate) disable_embedded_outpost: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ComplianceConfig {
pub(crate) fips: ComplianceFipsConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct ComplianceFipsConfig {
pub(crate) enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct WebConfig {
pub(crate) workers: usize,
pub(crate) threads: usize,
pub(crate) path: String,
pub(crate) timeout_http_read_header: String,
pub(crate) timeout_http_read: String,
pub(crate) timeout_http_write: String,
pub(crate) timeout_http_idle: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct WorkerConfig {
pub(crate) processes: NonZeroUsize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct StorageConfig {
pub(crate) backend: String,
pub(crate) file: StorageFileConfig,
pub(crate) media: Option<StorageOverrideConfig>,
pub(crate) reports: Option<StorageOverrideConfig>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct StorageFileConfig {
pub(crate) path: PathBuf,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub(crate) struct StorageOverrideConfig {
pub(crate) backend: Option<String>,
pub(crate) file: Option<StorageFileOverrideConfig>,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub(crate) struct StorageFileOverrideConfig {
pub(crate) path: Option<PathBuf>,
}

110
src/db.rs Normal file
View File

@@ -0,0 +1,110 @@
use std::{str::FromStr as _, sync::OnceLock, time::Duration};
use eyre::Result;
use sqlx::{
Executor as _, PgPool,
postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
};
use tracing::{info, log::LevelFilter, trace};
use crate::{
arbiter::{Arbiter, Tasks},
authentik_full_version, config,
mode::Mode,
};
static DB: OnceLock<PgPool> = OnceLock::new();
fn get_connect_opts() -> Result<PgConnectOptions> {
let config = config::get();
let mut opts = PgConnectOptions::new()
.application_name(&format!(
"authentik-{}@{}",
Mode::get(),
authentik_full_version()
))
.host(&config.postgresql.host)
.port(config.postgresql.port)
.username(&config.postgresql.user)
.password(&config.postgresql.password)
.database(&config.postgresql.name)
.ssl_mode(PgSslMode::from_str(&config.postgresql.sslmode)?);
if let Some(sslrootcert) = &config.postgresql.sslrootcert {
opts = opts.ssl_root_cert_from_pem(sslrootcert.as_bytes().to_vec());
}
if let Some(sslcert) = &config.postgresql.sslcert {
opts = opts.ssl_client_cert_from_pem(sslcert.as_bytes());
}
if let Some(sslkey) = &config.postgresql.sslkey {
opts = opts.ssl_client_key_from_pem(sslkey.as_bytes());
}
Ok(opts)
}
async fn update_connect_opts_on_config_change(arbiter: Arbiter) -> Result<()> {
let mut config_changed_rx = arbiter.config_changed_subscribe();
info!("starting database watcher for config changes");
loop {
tokio::select! {
res = config_changed_rx.changed() => {
if let Err(err) = res {
trace!("error receiving config changes: {err:?}");
break;
}
trace!("config change received, refreshing database connection options");
let db = get();
db.set_connect_options(get_connect_opts()?);
},
() = arbiter.shutdown() => break,
}
}
info!("stopping database watcher for config changes");
Ok(())
}
pub(crate) async fn init(tasks: &mut Tasks) -> Result<()> {
info!("initializing database pool");
let options = get_connect_opts()?;
let config = config::get();
let pool_options = PgPoolOptions::new()
.min_connections(1)
.max_connections(4)
.acquire_time_level(LevelFilter::Trace)
.max_lifetime(config.postgresql.conn_max_age.map(Duration::from_secs))
.test_before_acquire(config.postgresql.conn_health_checks)
.after_connect(|conn, _meta| {
Box::pin(async move {
let application_name =
format!("authentik-{}@{}", Mode::get(), authentik_full_version());
let default_schema = &config::get().postgresql.default_schema;
let query = format!(
"SET application_name = '{application_name}'; SET search_path = \
'{default_schema}';"
);
conn.execute(query.as_str()).await?;
Ok(())
})
});
let pool = pool_options.connect_with(options).await?;
DB.get_or_init(|| pool);
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!(
"{}::update_connect_opts_on_config_change",
module_path!(),
))
.spawn(update_connect_opts_on_config_change(arbiter))?;
info!("database pool initialized");
Ok(())
}
pub(crate) fn get() -> &'static PgPool {
DB.get()
.expect("failed to get db, has it been initialized?")
}

220
src/main.rs Normal file
View File

@@ -0,0 +1,220 @@
use std::{
process::exit,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use ::tracing::{error, info, trace};
use argh::FromArgs;
use eyre::{Result, eyre};
use crate::{arbiter::Tasks, config::ConfigManager, mode::Mode};
mod arbiter;
mod axum;
#[cfg(feature = "core")]
mod brands;
mod config;
#[cfg(feature = "core")]
mod db;
mod metrics;
mod mode;
#[cfg(feature = "proxy")]
mod proxy;
#[cfg(feature = "core")]
mod server;
mod tokio;
mod tracing;
#[cfg(feature = "core")]
mod worker;
const VERSION: &str = env!("CARGO_PKG_VERSION");
pub(crate) fn authentik_build_hash(fallback: Option<String>) -> String {
std::env::var("GIT_BUILD_HASH").unwrap_or_else(|_| fallback.unwrap_or_default())
}
pub(crate) fn authentik_full_version() -> String {
let build_hash = authentik_build_hash(None);
if build_hash.is_empty() {
VERSION.to_owned()
} else {
format!("{VERSION}+{build_hash}")
}
}
pub(crate) fn authentik_user_agent() -> String {
format!("authentik@{}", authentik_full_version())
}
#[derive(Debug, FromArgs, PartialEq)]
/// The authentication glue you need.
struct Cli {
#[argh(subcommand)]
command: Command,
}
#[derive(Debug, FromArgs, PartialEq)]
#[argh(subcommand)]
enum Command {
#[cfg(feature = "core")]
AllInOne(AllInOne),
#[cfg(feature = "core")]
Server(server::Cli),
#[cfg(feature = "core")]
Worker(worker::Cli),
#[cfg(feature = "proxy")]
Proxy(proxy::Cli),
#[cfg(feature = "core")]
Manage(Manage),
}
#[derive(Debug, FromArgs, PartialEq)]
/// Run the authentik server and worker.
#[argh(subcommand, name = "allinone")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
struct AllInOne {}
#[derive(Debug, FromArgs, PartialEq)]
/// authentik django's management command.
#[argh(subcommand, name = "manage")]
struct Manage {
#[argh(positional, greedy)]
args: Vec<String>,
}
fn main() -> Result<()> {
let tracing_crude = tracing::install_crude();
info!(version = env!("CARGO_PKG_VERSION"), "authentik is starting");
let cli: Cli = argh::from_env();
match &cli.command {
#[cfg(feature = "core")]
Command::AllInOne(_) => Mode::set(Mode::AllInOne)?,
#[cfg(feature = "core")]
Command::Server(_) => Mode::set(Mode::Server)?,
#[cfg(feature = "core")]
Command::Worker(_) => Mode::set(Mode::Worker)?,
#[cfg(feature = "proxy")]
Command::Proxy(_) => Mode::set(Mode::Proxy)?,
#[cfg(feature = "core")]
Command::Manage(args) => {
let mut process = std::process::Command::new("python")
.args(["-m", "manage"])
.args(&args.args)
.spawn()?;
let status = process.wait()?;
if let Some(code) = status.code() {
exit(code);
}
return Ok(());
}
}
trace!("installing error formatting");
color_eyre::install()?;
trace!("installing rustls crypto provider");
#[expect(
clippy::unwrap_in_result,
reason = "result type does not implement Error"
)]
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.expect("Failed to install rustls provider");
#[cfg(feature = "core")]
if Mode::is_core() {
if std::env::var("PROMETHEUS_MULTIPROC_DIR").is_err() {
let dir = std::env::temp_dir().join("authentik_prometheus_tmp");
std::fs::create_dir_all(&dir)?;
#[expect(unsafe_code, reason = "see safety comment below")]
// SAFETY: there is only one thread at this point, so this is safe.
unsafe {
std::env::set_var("PROMETHEUS_MULTIPROC_DIR", dir);
}
trace!(
env = std::env::var("PROMETHEUS_MULTIPROC_DIR").unwrap_or_default(),
"setting PROMETHEUS_MULTIPROC_DIR"
);
} else {
trace!("PROMETHEUS_MULTIPROC_DIR already set");
}
trace!("initializing Python");
pyo3::Python::initialize();
trace!("Python initialized");
}
ConfigManager::init()?;
let _sentry = config::get()
.error_reporting
.enabled
.then(tracing::sentry::install);
tracing::install()?;
drop(tracing_crude);
::tokio::runtime::Builder::new_multi_thread()
.thread_name_fn(|| {
static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("tokio-{id}")
})
.enable_all()
.build()?
.block_on(async {
let mut tasks = Tasks::new()?;
ConfigManager::run(&mut tasks)?;
let metrics = metrics::run(&mut tasks)?;
#[cfg(feature = "core")]
if Mode::is_core() {
db::init(&mut tasks).await?;
}
match cli.command {
#[cfg(feature = "core")]
Command::AllInOne(_) => {
let workers = worker::run(worker::Cli::default(), &mut tasks)?;
metrics.workers.store(Some(Arc::clone(&workers)));
let server = server::run(server::Cli::default(), &mut tasks)?;
server.workers.store(Some(workers));
metrics.server.store(Some(server));
}
#[cfg(feature = "core")]
Command::Server(args) => {
let server = server::run(args, &mut tasks)?;
metrics.server.store(Some(server));
}
#[cfg(feature = "core")]
Command::Worker(args) => {
let workers = worker::run(args, &mut tasks)?;
metrics.workers.store(Some(workers));
}
#[cfg(feature = "proxy")]
Command::Proxy(args) => proxy::run(args, &mut tasks)?,
#[cfg(feature = "core")]
Command::Manage(_) => unreachable!(),
}
let errors = tasks.run().await;
if errors.is_empty() {
info!("authentik exiting");
Ok(())
} else {
error!("authentik encountered errors: {:?}", errors);
Err(eyre!("Errors encountered: {:?}", errors))
}
})
}

91
src/metrics/handlers.rs Normal file
View File

@@ -0,0 +1,91 @@
use std::sync::Arc;
use axum::{body::Body, extract::State, http::StatusCode, response::Response};
#[cfg(feature = "core")]
use crate::mode::Mode;
use crate::{axum::error::Result, metrics::Metrics};
pub(super) async fn metrics_handler(State(state): State<Arc<Metrics>>) -> Result<Response> {
let mut metrics = Vec::new();
state.prometheus.render_to_write(&mut metrics)?;
#[cfg(feature = "core")]
if Mode::is_core() {
use axum::http::{Request, header::HOST};
if [Mode::AllInOne, Mode::Server].contains(&Mode::get()) {
let req = Request::builder()
.method("GET")
.uri("http://localhost:8000/-/metrics/")
.header(HOST, "localhost")
.body(Body::from(""));
if let Ok(req) = req
&& let Some(server) = state.server.load_full()
{
let _ = server.client.request(req).await;
}
} else if [Mode::Worker].contains(&Mode::get()) {
let req = Request::builder()
.method("GET")
.uri("http://localhost:8000/-/metrics/")
.header(HOST, "localhost")
.body(Body::from(""));
if let Ok(req) = req
&& let Some(workers) = state.workers.load_full()
{
let _ = workers.client.request(req).await;
}
}
metrics.extend(tokio::task::spawn_blocking(python::get_python_metrics).await??);
}
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "text/plain; version=1.0.0; charset=utf-8")
.body(Body::from(metrics))?)
}
#[cfg(feature = "core")]
mod python {
use eyre::{Report, Result};
use pyo3::{
IntoPyObjectExt as _,
ffi::c_str,
prelude::*,
types::{PyBytes, PyDict},
};
pub(super) fn get_python_metrics() -> Result<Vec<u8>> {
let metrics = Python::attach(|py| {
let locals = PyDict::new(py);
Python::run(
py,
c_str!(
r#"
from prometheus_client import (
CollectorRegistry,
generate_latest,
multiprocess,
)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
output = generate_latest(registry)
"#
),
None,
Some(&locals),
)?;
let metrics = locals
.get_item("output")?
.unwrap_or(PyBytes::new(py, &[]).into_bound_py_any(py)?)
.cast::<PyBytes>()
.map_or_else(|_| PyBytes::new(py, &[]), |v| v.to_owned())
.as_bytes()
.to_owned();
Ok::<_, Report>(metrics)
})?;
Ok::<_, Report>(metrics)
}
}

84
src/metrics/mod.rs Normal file
View File

@@ -0,0 +1,84 @@
use std::{env::temp_dir, os::unix, sync::Arc, time::Duration};
use arc_swap::ArcSwapOption;
use axum::{Router, routing::any};
use eyre::Result;
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
use crate::{
arbiter::{Arbiter, Tasks},
axum::{router::wrap_router, server},
config,
};
#[cfg(feature = "core")]
use crate::{server::Server, worker::Workers};
mod handlers;
pub(crate) struct Metrics {
prometheus: PrometheusHandle,
#[cfg(feature = "core")]
pub(crate) server: ArcSwapOption<Server>,
#[cfg(feature = "core")]
pub(crate) workers: ArcSwapOption<Workers>,
}
impl Metrics {
fn new() -> Result<Self> {
let prometheus = PrometheusBuilder::new()
.with_recommended_naming(true)
.install_recorder()?;
Ok(Self {
prometheus,
#[cfg(feature = "core")]
server: ArcSwapOption::empty(),
#[cfg(feature = "core")]
workers: ArcSwapOption::empty(),
})
}
}
async fn run_upkeep(arbiter: Arbiter, state: Arc<Metrics>) -> Result<()> {
loop {
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(5)) => {
let state_clone = Arc::clone(&state);
tokio::task::spawn_blocking(move || state_clone.prometheus.run_upkeep()).await?;
},
() = arbiter.shutdown() => return Ok(())
}
}
}
fn build_router(state: Arc<Metrics>) -> Router {
wrap_router(
Router::new()
.fallback(any(handlers::metrics_handler))
.with_state(state),
true,
)
}
pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
let arbiter = tasks.arbiter();
let metrics = Arc::new(Metrics::new()?);
let router = build_router(Arc::clone(&metrics));
tasks
.build_task()
.name(&format!("{}::run_upkeep", module_path!(),))
.spawn(run_upkeep(arbiter, Arc::clone(&metrics)))?;
for addr in config::get().listen.metrics.iter().copied() {
server::start_plain(tasks, "metrics", router.clone(), addr)?;
}
server::start_unix(
tasks,
"metrics",
router,
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik-metrics.sock"))?,
)?;
Ok(metrics)
}

78
src/mode.rs Normal file
View File

@@ -0,0 +1,78 @@
use std::{
env,
path::PathBuf,
sync::atomic::{AtomicU8, Ordering},
};
use eyre::Result;
static MODE: AtomicU8 = AtomicU8::new(0);
fn mode_path() -> PathBuf {
env::temp_dir().join("authentik-mode")
}
#[derive(PartialEq)]
#[repr(u8)]
pub(crate) enum Mode {
#[cfg(feature = "core")]
AllInOne = 0,
#[cfg(feature = "core")]
Server = 1,
#[cfg(feature = "core")]
Worker = 2,
#[cfg(feature = "proxy")]
Proxy = 3,
}
impl std::fmt::Display for Mode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
#[cfg(feature = "core")]
Self::AllInOne => write!(f, "allinone"),
#[cfg(feature = "core")]
Self::Server => write!(f, "server"),
#[cfg(feature = "core")]
Self::Worker => write!(f, "worker"),
#[cfg(feature = "proxy")]
Self::Proxy => write!(f, "proxy"),
}
}
}
impl From<Mode> for u8 {
#[expect(clippy::as_conversions, reason = "repr of enum is u8")]
fn from(value: Mode) -> Self {
value as Self
}
}
impl Mode {
pub(crate) fn get() -> Self {
match MODE.load(Ordering::Relaxed) {
#[cfg(feature = "core")]
0 => Self::AllInOne,
#[cfg(feature = "core")]
1 => Self::Server,
#[cfg(feature = "core")]
2 => Self::Worker,
#[cfg(feature = "proxy")]
3 => Self::Proxy,
_ => unreachable!(),
}
}
pub(crate) fn set(mode: Self) -> Result<()> {
std::fs::write(mode_path(), mode.to_string())?;
MODE.store(mode.into(), Ordering::SeqCst);
Ok(())
}
pub(crate) fn is_core() -> bool {
match Self::get() {
#[cfg(feature = "core")]
Self::AllInOne | Self::Server | Self::Worker => true,
_ => false,
}
}
}

48
src/proxy/mod.rs Normal file
View File

@@ -0,0 +1,48 @@
use argh::FromArgs;
use axum::extract::Request;
use eyre::Result;
use crate::arbiter::{Arbiter, Tasks};
#[derive(Debug, FromArgs, PartialEq)]
/// Run the authentik proxy outpost.
#[argh(subcommand, name = "proxy")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}
pub(crate) mod tls {
use std::sync::Arc;
use rustls::{server::ClientHello, sign::CertifiedKey};
#[derive(Debug)]
pub(crate) struct CertResolver;
impl CertResolver {
#[expect(clippy::unused_self, reason = "still WIP")]
pub(crate) fn resolve(&self, _client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
None
}
}
}
pub(crate) fn can_handle(_request: &Request) -> bool {
false
}
pub(crate) async fn ignore_me(arbiter: Arbiter) -> Result<()> {
arbiter.shutdown().await;
Ok(())
}
pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<()> {
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!("{}::ignore_me", module_path!(),))
.spawn(ignore_me(arbiter))?;
Ok(())
}

448
src/server/core.rs Normal file
View File

@@ -0,0 +1,448 @@
use std::sync::{Arc, LazyLock, atomic::Ordering};
use axum::{
Extension, Router,
body::Body,
extract::{Request, State},
http::{
HeaderName, HeaderValue, StatusCode, Uri,
header::{ACCEPT, CONTENT_TYPE, HOST, LOCATION, RETRY_AFTER},
},
middleware::{Next, from_fn},
response::{IntoResponse, Response},
routing::any,
};
use http_body_util::BodyExt as _;
use serde_json::json;
use crate::{
axum::{
accept::tls::TlsState,
error::Result,
extract::{client_ip::ClientIp, host::Host, scheme::Scheme, trusted_proxy::TrustedProxy},
router::wrap_router,
},
config, db,
server::{
GUNICORN_READY, Server,
core::websockets::{handle_websocket_upgrade, is_websocket_upgrade},
},
};
static STARTUP_RESPONSE_JSON: LazyLock<Response<String>> = LazyLock::new(|| {
Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.header(RETRY_AFTER, "5")
.header(CONTENT_TYPE, "application/json")
.body(
json!({
"error": "authentik starting",
})
.to_string(),
)
.expect("infallible")
});
static STARTUP_RESPONSE_HTML: LazyLock<Response<String>> = LazyLock::new(|| {
Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.header(CONTENT_TYPE, "text/html")
.body(include_str!("../../web/dist/standalone/loading/startup.html").to_owned())
.expect("infallible")
});
static STARTUP_RESPONSE_PLAIN: LazyLock<Response<String>> = LazyLock::new(|| {
Response::builder()
.status(StatusCode::SERVICE_UNAVAILABLE)
.header(CONTENT_TYPE, "text/plain")
.body("authentik starting".to_owned())
.expect("infallible")
});
const SERVER: HeaderName = HeaderName::from_static("server");
const X_FORWARDED_CLIENT_CERT: HeaderName = HeaderName::from_static("x-forwarded-client-cert");
const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
const X_FORWARDED_PROTO: HeaderName = HeaderName::from_static("x-forwarded-proto");
const X_POWERED_BY: HeaderName = HeaderName::from_static("x-powered-by");
const FORWARD_ALWAYS_REMOVED_HEADERS: [HeaderName; 7] = [
HeaderName::from_static("forwarded"),
HeaderName::from_static("host"),
X_FORWARDED_FOR,
HeaderName::from_static("x-forwarded-host"),
X_FORWARDED_PROTO,
HeaderName::from_static("x-forwarded-scheme"),
HeaderName::from_static("x-real-ip"),
];
const FORWARD_REMOVED_HEADERS_IF_UNTRUSTED: [HeaderName; 3] = [
HeaderName::from_static("ssl-client-cert"), // nginx-ingress
HeaderName::from_static("x-forwarded-tls-client-cert"), // traefik
X_FORWARDED_CLIENT_CERT, // envoy
];
fn startup_response(accept_header: &str) -> Response {
let response = if accept_header.contains("application/json") {
STARTUP_RESPONSE_JSON.clone()
} else if accept_header.contains("text/html") {
STARTUP_RESPONSE_HTML.clone()
} else {
STARTUP_RESPONSE_PLAIN.clone()
};
let (parts, body) = response.into_parts();
Response::from_parts(parts, body.into())
}
async fn forward_request(
ClientIp(client_ip): ClientIp,
Host(host): Host,
Scheme(scheme): Scheme,
State(server): State<Arc<Server>>,
TrustedProxy(trusted_proxy): TrustedProxy,
tls_state: Option<Extension<TlsState>>,
mut request: Request,
) -> Result<Response> {
let accept_header = request
.headers()
.get(ACCEPT)
.map(|v| v.to_str().unwrap_or_default().to_owned())
.unwrap_or_default();
if !GUNICORN_READY.load(Ordering::Relaxed) {
return Ok(startup_response(&accept_header));
}
let uri = Uri::builder()
.scheme("http")
.authority("localhost:8000")
.path_and_query(
request
.uri()
.path_and_query()
.map(|x| x.as_str())
.unwrap_or_default(),
)
.build()?;
*request.uri_mut() = uri;
for header_name in FORWARD_ALWAYS_REMOVED_HEADERS {
request.headers_mut().remove(header_name);
}
if !trusted_proxy {
for header_name in FORWARD_REMOVED_HEADERS_IF_UNTRUSTED {
request.headers_mut().remove(header_name);
}
}
request.headers_mut().insert(
X_FORWARDED_FOR,
HeaderValue::from_str(&client_ip.to_string())?,
);
request
.headers_mut()
.insert(HOST, HeaderValue::from_str(&host)?);
request
.headers_mut()
.insert(X_FORWARDED_PROTO, HeaderValue::from_str(scheme.as_ref())?);
if is_websocket_upgrade(request.headers()) {
return handle_websocket_upgrade(request, server).await;
}
if let Some(tls_state) = tls_state
&& let Some(peer_certificates) = &tls_state.peer_certificates
{
let xfcc = peer_certificates
.iter()
.map(|cert| {
let pem_encoded = pem::encode(&pem::Pem::new("CERTIFICATE", cert.as_ref()));
let url_encoded: String =
url::form_urlencoded::byte_serialize(pem_encoded.as_bytes()).collect();
format!("Cert={url_encoded}")
})
.collect::<Vec<_>>()
.join(",");
request
.headers_mut()
.insert("X_FORWARDED_CLIENT_CERT", HeaderValue::from_str(&xfcc)?);
}
match server.client.request(request).await {
Ok(res) => {
let (parts, body) = res.into_parts();
Ok(Response::from_parts(
parts,
Body::from_stream(body.into_data_stream()),
))
}
Err(_) => Ok(startup_response(&accept_header)),
}
}
fn build_gunicorn_router(server: Arc<Server>) -> Router {
wrap_router(
Router::new().fallback(forward_request).with_state(server),
config::get().debug, // enable tracing only in debug mode
)
}
async fn powered_by_middleware(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
response.headers_mut().remove(SERVER);
response
.headers_mut()
.insert(X_POWERED_BY, HeaderValue::from_static("authentik"));
response
}
async fn health_ready(State(server): State<Arc<Server>>) -> impl IntoResponse {
#[expect(clippy::if_same_then_else, reason = "For easier reading")]
if !server.is_alive().await {
StatusCode::SERVICE_UNAVAILABLE
} else if sqlx::query("SELECT 1").execute(db::get()).await.is_err() {
StatusCode::SERVICE_UNAVAILABLE
} else if let Some(workers) = server.workers.load_full()
&& !workers.are_alive().await
{
StatusCode::SERVICE_UNAVAILABLE
} else {
let req = Request::builder()
.method("GET")
.uri("http://localhost:8000/-/health/ready/")
.header(HOST, "localhost")
.body(Body::from(""));
if let Ok(req) = req
&& let Ok(res) = server.client.request(req).await
{
res.status()
} else {
StatusCode::SERVICE_UNAVAILABLE
}
}
}
pub(super) fn build_router(server: Arc<Server>) -> Router {
let router = wrap_router(
Router::new()
.route("/-/metrics/", any((StatusCode::NOT_FOUND, "not found")))
.route("/-/health/ready/", any(health_ready))
.with_state(Arc::clone(&server))
.merge(super::r#static::build_router()),
true,
)
.merge(build_gunicorn_router(server))
.layer(from_fn(powered_by_middleware));
let path = &config::get().web.path;
if config::get().web.path == "/" {
router
} else {
Router::new()
.route(
"/",
any(
async || match HeaderValue::try_from(&config::get().web.path) {
Ok(location) => (StatusCode::FOUND, [(LOCATION, location)]).into_response(),
Err(err) => {
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
}
},
),
)
.nest(path, router)
}
}
mod websockets {
use std::sync::Arc;
use axum::{
body::Body,
extract::Request,
http::{
HeaderMap, HeaderValue, StatusCode,
header::{
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
},
},
response::{IntoResponse as _, Response},
};
use futures::{SinkExt as _, StreamExt as _};
use hyper_util::rt::TokioIo;
use tokio::{net::UnixStream, sync::mpsc};
use tokio_tungstenite::{
WebSocketStream, client_async,
tungstenite::{Message, handshake::derive_accept_key, protocol::Role},
};
use tracing::{debug, trace, warn};
use crate::{
axum::error::{AppError, Result},
server::Server,
};
pub(super) fn is_websocket_upgrade(headers: &HeaderMap<HeaderValue>) -> bool {
let has_upgrade = headers
.get(UPGRADE)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.eq_ignore_ascii_case("websocket"));
let has_connection = headers
.get(CONNECTION)
.and_then(|v| v.to_str().ok())
.is_some_and(|v| {
v.split(',')
.any(|part| part.trim().eq_ignore_ascii_case("upgrade"))
});
let has_websocket_key = headers.contains_key(SEC_WEBSOCKET_KEY);
let has_websocket_version = headers.contains_key(SEC_WEBSOCKET_VERSION);
has_upgrade && has_connection && has_websocket_key && has_websocket_version
}
pub(super) async fn handle_websocket_upgrade(
request: Request,
server: Arc<Server>,
) -> Result<Response> {
let Some(ws_key) = request
.headers()
.get(SEC_WEBSOCKET_KEY)
.and_then(|key| key.to_str().ok())
else {
return Ok((StatusCode::BAD_REQUEST, "").into_response());
};
let ws_accept = derive_accept_key(ws_key.as_bytes());
let path_q = request
.uri()
.path_and_query()
.map(|x| x.as_str())
.unwrap_or_default();
let uri = format!("ws://localhost:8000{path_q}");
let mut ws_request =
tokio_tungstenite::tungstenite::handshake::client::Request::builder().uri(uri);
for (k, v) in request.headers() {
ws_request = ws_request.header(k.as_str(), v);
}
let ws_request = ws_request.body(())?;
let response = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(UPGRADE, "websocket")
.header(CONNECTION, "upgrade")
.header(SEC_WEBSOCKET_ACCEPT, ws_accept)
.body(Body::empty())?;
tokio::spawn(async move {
if let Err(err) = handle_websocket_connection(request, server, ws_request).await {
warn!("WebSocket connection error: {}", err.0);
}
});
Ok(response)
}
async fn handle_websocket_connection(
request: Request,
server: Arc<Server>,
ws_request: tokio_tungstenite::tungstenite::handshake::client::Request,
) -> Result<()> {
let upgraded = hyper::upgrade::on(request).await?;
let io = TokioIo::new(upgraded);
let client_ws = WebSocketStream::from_raw_socket(io, Role::Server, None).await;
let upstream_ws = {
let stream = UnixStream::connect(&server.socket_path).await?;
let (ws_stream, _) = client_async(ws_request, stream).await?;
ws_stream
};
let (mut client_sender, mut client_receiver) = client_ws.split();
let (mut upstream_sender, mut upstream_receiver) = upstream_ws.split();
let (close_tx, mut close_rx) = mpsc::channel::<()>(1);
let close_tx_upstream = close_tx.clone();
let client_to_upstream = tokio::spawn(async move {
let mut client_closed = false;
while let Some(msg) = client_receiver.next().await {
let msg = msg?;
match msg {
Message::Close(_) => {
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
let _ = close_tx.send(()).await;
client_closed = true;
break;
}
}
msg @ (Message::Binary(_)
| Message::Text(_)
| Message::Ping(_)
| Message::Pong(_)) => {
if !client_closed {
upstream_sender.send(msg).await?;
}
}
Message::Frame(_) => {}
}
}
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
let _ = close_tx.send(()).await;
}
Ok::<_, AppError>(())
});
let upstream_to_client = tokio::spawn(async move {
let mut upstream_closed = false;
while let Some(msg) = upstream_receiver.next().await {
let msg = msg?;
match msg {
Message::Close(_) => {
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
let _ = close_tx_upstream.send(()).await;
upstream_closed = true;
break;
}
}
msg @ (Message::Binary(_)
| Message::Text(_)
| Message::Ping(_)
| Message::Pong(_)) => {
if !upstream_closed {
client_sender.send(msg).await?;
}
}
Message::Frame(_) => {}
}
}
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
let _ = close_tx_upstream.send(()).await;
}
Ok::<_, AppError>(())
});
tokio::select! {
_ = close_rx.recv() => {
trace!("WebSocket connection closed gracefully");
},
res = client_to_upstream => {
if let Err(err) = res {
debug!("Client to upstream task failed: {:?}", err);
}
}
res = upstream_to_client => {
if let Err(err) = res {
debug!("Upstream to client task failed: {:?}", err);
}
}
}
Ok(())
}
}

255
src/server/mod.rs Normal file
View File

@@ -0,0 +1,255 @@
use std::{
env::temp_dir,
os::unix,
path::PathBuf,
process::Stdio,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use arc_swap::ArcSwapOption;
use argh::FromArgs;
use axum::{Router, body::Body, extract::Request, http::status::StatusCode, routing::any};
use eyre::{Result, eyre};
use hyper_unix_socket::UnixSocketConnector;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use nix::{
sys::signal::{Signal, kill},
unistd::Pid,
};
use tokio::{
net::UnixStream,
process::{Child, Command},
signal::unix::SignalKind,
sync::{Mutex, broadcast::error::RecvError},
time::Instant,
};
use tower::ServiceExt as _;
use tower_http::timeout::TimeoutLayer;
use tracing::{info, trace, warn};
use crate::{
arbiter::{Arbiter, Tasks},
axum::server,
config,
worker::Workers,
};
pub(super) static GUNICORN_READY: AtomicBool = AtomicBool::new(false);
pub(crate) mod core;
mod r#static;
mod tls;
#[derive(Debug, Default, FromArgs, PartialEq)]
/// Run the authentik server.
#[argh(subcommand, name = "server")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(super) struct Cli {}
pub(crate) struct Server {
gunicorn: Mutex<Child>,
socket_path: PathBuf,
pub(crate) client: Client<UnixSocketConnector<PathBuf>, Body>,
pub(crate) workers: ArcSwapOption<Workers>,
}
impl Server {
fn new(socket_path: PathBuf) -> Result<Self> {
info!("starting gunicorn");
let gunicorn = Command::new("gunicorn")
.args([
"--bind",
&format!("unix://{}", socket_path.display()),
"-c",
"./lifecycle/gunicorn.conf.py",
"authentik.root.asgi:application",
])
.kill_on_drop(true)
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.spawn()?;
let client = Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(60))
.set_host(false)
.build(UnixSocketConnector::new(socket_path.clone()));
Ok(Self {
gunicorn: Mutex::new(gunicorn),
socket_path,
client,
workers: ArcSwapOption::empty(),
})
}
async fn shutdown(&self, signal: Signal) -> Result<()> {
trace!(
signal = signal.as_str(),
"sending shutdown signal to gunicorn"
);
let mut gunicorn = self.gunicorn.lock().await;
if let Some(id) = gunicorn.id() {
kill(Pid::from_raw(id.cast_signed()), signal)?;
}
gunicorn.wait().await?;
drop(gunicorn);
Ok(())
}
async fn graceful_shutdown(&self) -> Result<()> {
info!("gracefully shutting down gunicorn");
self.shutdown(Signal::SIGTERM).await
}
async fn fast_shutdown(&self) -> Result<()> {
info!("immediately shutting down gunicorn");
self.shutdown(Signal::SIGINT).await
}
async fn is_alive(&self) -> bool {
let try_wait = self.gunicorn.lock().await.try_wait();
match try_wait {
Ok(Some(code)) => {
warn!("gunicorn has exited with status {code}");
false
}
Ok(None) => true,
Err(err) => {
warn!("failed to check the status of gunicorn process, ignoring: {err}");
true
}
}
}
async fn is_socket_ready(&self) -> bool {
let result = UnixStream::connect(&self.socket_path).await;
trace!("checking if gunicorn is ready: {result:?}");
result.is_ok()
}
}
async fn watch_server(arbiter: Arbiter, server: Arc<Server>) -> Result<()> {
info!("starting server watcher");
let mut signals_rx = arbiter.signals_subscribe();
loop {
tokio::select! {
signal = signals_rx.recv() => {
match signal {
Ok(signal) => {
if signal == SignalKind::user_defined1() {
info!("gunicorn notified us ready, marked ready for operation");
GUNICORN_READY.store(true, Ordering::Relaxed);
arbiter.mark_gunicorn_ready();
}
},
Err(RecvError::Lagged(_)) => {},
Err(RecvError::Closed) => {
warn!("error receiving signals");
return Err(RecvError::Closed.into());
}
}
},
() = tokio::time::sleep(Duration::from_secs(1)), if !GUNICORN_READY.load(Ordering::Relaxed) => {
// On some platforms the SIGUSR1 can be missed.
// Fall back to probing the gunicorn unix socket and mark ready once it accepts connections.
if server.is_socket_ready().await {
info!("gunicorn socket is accepting connections, marked ready for operation");
GUNICORN_READY.store(true, Ordering::Relaxed);
arbiter.mark_gunicorn_ready();
}
},
() = tokio::time::sleep(Duration::from_secs(5)) => {
if !server.is_alive().await {
return Err(eyre!("gunicorn has exited unexpectedly"));
}
},
() = arbiter.fast_shutdown() => {
server.fast_shutdown().await?;
return Ok(());
},
() = arbiter.graceful_shutdown() => {
server.graceful_shutdown().await?;
return Ok(());
},
}
}
}
fn build_router(server: Arc<Server>) -> Router {
let core_router = core::build_router(server);
let proxy_router: Option<Router> = None;
let config = config::get();
let timeout = durstr::parse(&config.web.timeout_http_read_header)
.expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_read).expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_write).expect("Invalid duration in http timeout")
+ durstr::parse(&config.web.timeout_http_idle).expect("Invalid duration in http timeout");
let timeout_layer = TimeoutLayer::with_status_code(StatusCode::REQUEST_TIMEOUT, timeout);
Router::new()
.fallback(any(async |request: Request<Body>| {
metrics::describe_histogram!(
"authentik_main_request_duration",
metrics::Unit::Seconds,
"API request latencies in seconds"
);
let now = Instant::now();
if let Some(proxy_router) = proxy_router
&& crate::proxy::can_handle(&request)
{
let res = proxy_router.oneshot(request).await;
metrics::histogram!("authentik_main_request_duration", "dest" => "embedded_outpost")
.record(now.elapsed());
res
} else {
let res = core_router.oneshot(request).await;
metrics::histogram!("authentik_main_request_duration", "dest" => "core")
.record(now.elapsed());
res
}
}))
.layer(timeout_layer)
}
pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
let config = config::get();
let arbiter = tasks.arbiter();
let server = Arc::new(Server::new(temp_dir().join("authentik-gunicorn.sock"))?);
tasks
.build_task()
.name(&format!("{}::watch_server", module_path!()))
.spawn(watch_server(arbiter.clone(), Arc::clone(&server)))?;
let router = build_router(Arc::clone(&server));
for addr in config.listen.http.iter().copied() {
server::start_plain(tasks, "server", router.clone(), addr)?;
}
let tls_config = tls::make_initial_tls_config()?;
for addr in config.listen.https.iter().copied() {
server::start_tls(tasks, "tls", router.clone(), addr, tls_config.clone())?;
}
tasks
.build_task()
.name(&format!("{}::tls::watch_tls_config", module_path!(),))
.spawn(tls::watch_tls_config(arbiter, tls_config))?;
server::start_unix(
tasks,
"server",
router,
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
)?;
Ok(server)
}

255
src/server/static.rs Normal file
View File

@@ -0,0 +1,255 @@
use std::fmt::Write as _;
use aws_lc_rs::digest;
use axum::{
Router,
extract::{Query, Request, State},
http::{
HeaderValue, StatusCode,
header::{CACHE_CONTROL, CONTENT_SECURITY_POLICY, VARY},
},
middleware::{self, Next},
response::{IntoResponse as _, Response},
routing::any,
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
use percent_encoding::percent_decode_str;
use serde::Deserialize;
use time::OffsetDateTime;
use tower_http::{
compression::{CompressionLayer, predicate::SizeAbove},
services::fs::ServeDir,
};
use crate::config;
#[derive(Debug, Deserialize)]
struct StorageClaims {
exp: Option<i64>,
nbf: Option<i64>,
path: Option<String>,
}
#[derive(Debug, Deserialize)]
struct StorageTokenQuery {
token: Option<String>,
}
fn is_storage_token_valid(usage: &str, secret_key: &str, request: &Request) -> bool {
// Use typed query parsing so `token` is percent-decoded before JWT parsing.
let token_string = match Query::<StorageTokenQuery>::try_from_uri(request.uri()) {
Ok(query) => match query.0.token {
Some(token) if !token.is_empty() => token,
_ => return false,
},
Err(_) => return false,
};
let Ok(token_header) = decode_header(&token_string) else {
return false;
};
// Must match what we use in authentik/admin/files/backends/file.py
if token_header.alg != Algorithm::HS256 {
return false;
}
// Derive a per-usage key so media and reports tokens are not interchangeable.
let key = format!("{secret_key}:{usage}");
let key_digest = digest::digest(&digest::SHA256, key.as_bytes());
let key_hex_digest = key_digest
.as_ref()
.iter()
.fold(String::new(), |mut acc, b| {
let _ = write!(acc, "{b:02x}");
acc
});
let mut validation = Validation::new(token_header.alg);
validation.validate_exp = false;
validation.validate_nbf = false;
validation.validate_aud = false;
validation.required_spec_claims.clear();
let claims = match decode::<StorageClaims>(
&token_string,
&DecodingKey::from_secret(key_hex_digest.as_bytes()),
&validation,
) {
Ok(token) => token.claims,
Err(_) => return false,
};
let now = OffsetDateTime::now_utc().unix_timestamp();
if claims.exp.unwrap_or(0) < now {
return false;
}
if claims.nbf.unwrap_or(now + 1) > now {
return false;
}
let Some(claim_path) = claims.path else {
return false;
};
// Decode path before comparison so encoded URL segments cannot bypass path binding.
let Ok(request_path) = percent_decode_str(request.uri().path()).decode_utf8() else {
return false;
};
let request_path = request_path.trim_start_matches('/');
let expected_path = format!("{usage}/{request_path}");
if claim_path != expected_path {
return false;
}
true
}
#[derive(Clone)]
struct StorageMiddlewareConfig {
usage: &'static str,
set_csp_header: bool,
}
async fn storage_middleware(
State(config): State<StorageMiddlewareConfig>,
request: Request,
next: Next,
) -> Response {
if !is_storage_token_valid(config.usage, &config::get().secret_key, &request) {
return (StatusCode::NOT_FOUND, "404 page not found\n").into_response();
}
let mut response = next.run(request).await;
if config.set_csp_header {
// Since media is user-controlled, better be safe
response.headers_mut().insert(
CONTENT_SECURITY_POLICY,
HeaderValue::from_static("default-src 'none'; style-src 'unsafe-inline'; sandbox"),
);
}
response
}
async fn static_header_middleware(request: Request, next: Next) -> Response {
let mut response = next.run(request).await;
response.headers_mut().insert(
CACHE_CONTROL,
HeaderValue::from_static("public, no-transform"),
);
response.headers_mut().insert(
"X-authentik-version",
HeaderValue::from_static(env!("CARGO_PKG_VERSION")),
);
response
.headers_mut()
.insert(VARY, HeaderValue::from_static("X-authentik-version, Etag"));
// TODO: etag
response
}
pub(crate) fn build_router() -> Router {
let config = config::get();
let mut router = Router::new().layer(middleware::from_fn(static_header_middleware));
let dist_fs = ServeDir::new("./web/dist/").append_index_html_on_directories(false);
let static_fs = ServeDir::new("./web/authentik/").append_index_html_on_directories(false);
router = router.nest_service("/static/dist/", dist_fs.clone());
router = router.nest_service("/static/authentik/", static_fs);
router = router.nest_service("/if/flow/{flow_slug}/assets/", dist_fs.clone());
router = router.nest_service("/if/admin/assets/", dist_fs.clone());
router = router.nest_service("/if/user/assets/", dist_fs.clone());
router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs);
let default_backend = &config.storage.backend;
let media_backend = config
.storage
.media
.clone()
.unwrap_or_default()
.backend
.unwrap_or_else(|| default_backend.clone());
let reports_backend = config
.storage
.reports
.clone()
.unwrap_or_default()
.backend
.unwrap_or_else(|| default_backend.clone());
let default_path = &config.storage.file.path;
if media_backend == "file" {
let media_path = config
.storage
.media
.clone()
.unwrap_or_default()
.file
.unwrap_or_default()
.path
.unwrap_or_else(|| default_path.clone())
.join("media");
let media_fs = ServeDir::new(media_path).append_index_html_on_directories(false);
let media_router =
Router::new()
.fallback_service(media_fs)
.layer(middleware::from_fn_with_state(
StorageMiddlewareConfig {
usage: "media",
set_csp_header: true,
},
storage_middleware,
));
router = router.nest("/files/media/", media_router);
}
if reports_backend == "file" {
let reports_path = config
.storage
.reports
.clone()
.unwrap_or_default()
.file
.unwrap_or_default()
.path
.unwrap_or_else(|| default_path.clone())
.join("reports");
let reports_fs = ServeDir::new(reports_path).append_index_html_on_directories(false);
let reports_router =
Router::new()
.fallback_service(reports_fs)
.layer(middleware::from_fn_with_state(
StorageMiddlewareConfig {
usage: "reports",
set_csp_header: false,
},
storage_middleware,
));
router = router.nest("/files/reports/", reports_router);
}
router = router.route(
"/robots.txt",
any(async || include_str!("../../web/robots.txt")),
);
router = router.route(
"/.well-known/security.txt",
any(async || include_str!("../../web/security.txt")),
);
router = router.layer(middleware::from_fn(static_header_middleware));
router = router.layer(CompressionLayer::new().compress_when(SizeAbove::new(32)));
router
}

147
src/server/tls.rs Normal file
View File

@@ -0,0 +1,147 @@
use std::{sync::Arc, time::Duration};
use axum_server::tls_rustls::RustlsConfig;
use eyre::Result;
use rcgen::PKCS_ECDSA_P256_SHA256;
use rustls::{
ServerConfig,
server::{ClientHello, ResolvesServerCert, WebPkiClientVerifier},
sign::CertifiedKey,
};
use tracing::{debug, warn};
use crate::{arbiter::Arbiter, brands, proxy};
pub(super) fn make_initial_tls_config() -> Result<RustlsConfig> {
let (cert, keypair) = self_signed::generate(&PKCS_ECDSA_P256_SHA256)?;
Ok(RustlsConfig::from_config(Arc::new(
ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(vec![cert.into()], keypair.into())?,
)))
}
async fn make_tls_config(fallback: Arc<CertifiedKey>) -> Result<ServerConfig> {
let (core_resolver, roots) = brands::tls::make_cert_managers().await?;
let cert_resolver = CertResolver {
core_resolver,
proxy_resolver: None,
fallback,
};
let client_cert_verifier = WebPkiClientVerifier::builder(Arc::new(roots))
.allow_unauthenticated()
.build()?;
Ok(ServerConfig::builder()
.with_client_cert_verifier(client_cert_verifier)
.with_cert_resolver(Arc::new(cert_resolver)))
}
pub(super) async fn watch_tls_config(arbiter: Arbiter, config: RustlsConfig) -> Result<()> {
tokio::select! {
() = arbiter.gunicorn_ready() => {},
() = arbiter.shutdown() => return Ok(()),
}
let fallback = Arc::new(self_signed::generate_certifiedkey(&PKCS_ECDSA_P256_SHA256)?);
loop {
match make_tls_config(Arc::clone(&fallback)).await {
Ok(new_config) => {
config.reload_from_config(Arc::new(new_config));
debug!("reloaded tls config");
}
Err(err) => {
warn!("error while reloading tls config {err:?}");
}
}
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(60)) => {},
() = arbiter.shutdown() => return Ok(()),
}
}
}
#[derive(Debug)]
struct CertResolver {
core_resolver: brands::tls::CertResolver,
proxy_resolver: Option<proxy::tls::CertResolver>,
fallback: Arc<CertifiedKey>,
}
#[expect(
clippy::missing_trait_methods,
reason = "the provided methods are sensible enough"
)]
impl ResolvesServerCert for CertResolver {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if client_hello.server_name().is_none() {
Some(Arc::clone(&self.fallback))
} else if let Some(resolver) = &self.proxy_resolver
&& let Some(cert) = resolver.resolve(&client_hello)
{
Some(cert)
} else if let Some(cert) = self.core_resolver.resolve(&client_hello) {
Some(cert)
} else {
Some(Arc::clone(&self.fallback))
}
}
}
mod self_signed {
use eyre::Result;
use rcgen::{
Certificate, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
KeyPair, KeyUsagePurpose, SanType, SignatureAlgorithm,
};
use rustls::{
crypto::aws_lc_rs::sign::any_supported_type,
pki_types::{CertificateDer, PrivateKeyDer},
sign::CertifiedKey,
};
use time::{Duration, OffsetDateTime};
pub(super) fn generate(alg: &'static SignatureAlgorithm) -> Result<(Certificate, KeyPair)> {
let signing_key = KeyPair::generate_for(alg)?;
let mut params = CertificateParams::default();
params.not_before = OffsetDateTime::now_utc();
params.not_after = OffsetDateTime::now_utc() + Duration::days(365);
params.distinguished_name = {
let mut dn = DistinguishedName::new();
dn.push(DnType::OrganizationName, "authentik");
dn.push(DnType::CommonName, "authentik default certificate");
dn
};
params.subject_alt_names = vec![SanType::DnsName("*".try_into()?)];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
KeyUsagePurpose::KeyEncipherment,
];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
let cert = params.self_signed(&signing_key)?;
Ok((cert, signing_key))
}
pub(super) fn generate_certifiedkey(alg: &'static SignatureAlgorithm) -> Result<CertifiedKey> {
let (cert, keypair) = generate(alg)?;
let cert_der = cert.der().to_vec();
let key_der = keypair.serialize_der();
let private_key =
PrivateKeyDer::try_from(key_der).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
let signing_key =
any_supported_type(&private_key).map_err(|_| rcgen::Error::CouldNotParseKeyPair)?;
Ok(CertifiedKey::new(
vec![CertificateDer::from(cert_der)],
signing_key,
))
}
}

1
src/tokio/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub(crate) mod proxy_protocol;

View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2026 Authentik Security Inc.
Copyright (c) 2023 Tibor Djurica Potpara
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1 @@
This is a fork of https://github.com/tibordp/proxy-header/, with the sync code removed, the encoding code removed, and the ability to make the PROXY protocol optional.

View File

@@ -0,0 +1,634 @@
use std::{borrow::Cow, fmt, net::SocketAddr, str::from_utf8};
use thiserror::Error;
use tracing::instrument;
/// Protocol type
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub(crate) enum Protocol {
/// Stream protocol (TCP)
Stream,
/// Datagram protocol (UDP)
Datagram,
}
/// Address information from a PROXY protocol header
#[derive(Debug, PartialEq, Eq, Clone, Hash)]
pub(crate) struct Address {
/// Protocol type
pub(crate) protocol: Protocol,
/// Source address (of the actual client)
pub(crate) source: SocketAddr,
/// Destination address (of the proxy)
pub(crate) destination: SocketAddr,
}
macro_rules! tlv {
($self:expr, $kind:ident) => {{
$self.tlvs().find_map(|f| match f {
Ok(Tlv::$kind(v)) => Some(v),
_ => None,
})
}};
}
macro_rules! tlv_borrowed {
($self:expr, $kind:ident) => {{
$self.tlvs().find_map(|f| match f {
Ok(Tlv::$kind(v)) => match v {
// It is more ergonomic to return the borrowed value directly rather
// than it wrapped in a `Cow::Borrowed`. We know that tlvs always borrows
// so we can safely unwrap the `Cow::Borrowed` and return the borrowed value.
Cow::Owned(_) => unreachable!(),
Cow::Borrowed(v) => Some(v),
},
_ => None,
})
}};
}
/// Iterator over PROXY protocol TLV fields
pub(crate) struct Tlvs<'a> {
buf: &'a [u8],
}
#[expect(
clippy::missing_trait_methods,
reason = "we don't need to implement the other methods here"
)]
impl<'a> Iterator for Tlvs<'a> {
type Item = Result<Tlv<'a>, ProxyProtocolError>;
fn next(&mut self) -> Option<Self::Item> {
if self.buf.is_empty() {
return None;
}
let kind = self.buf[0];
match self
.buf
.get(1..3)
.map(|s| -> usize { u16::from_be_bytes(s.try_into().expect("infallible")).into() })
{
Some(u) if u + 3 <= self.buf.len() => {
let (ret, new) = self.buf.split_at(3 + u);
self.buf = new;
Some(Tlv::decode(kind, &ret[3..]))
}
_ => {
// Malformed TLV, cannot continue
self.buf = &[];
Some(Err(ProxyProtocolError::Invalid))
}
}
}
}
/// Typed TLV field
///
/// Represents the currently known types of TLV fields from the PROXY protocol specification.
/// Non-recognized TLV fields are represented as [`Tlv::Custom`].
#[non_exhaustive]
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) enum Tlv<'a> {
/// Application-Layer Protocol Negotiation (ALPN). It is a byte sequence defining the upper
/// layer protocol in use over the connection. The most common use case will be to pass the
/// exact copy of the ALPN extension of the Transport Layer Security (TLS) protocol as defined
/// by RFC7301.
Alpn(Cow<'a, [u8]>),
/// Contains the host name value passed by the client, as an UTF-8 encoded string. In case of
/// TLS being used on the client connection, this is the exact copy of the `server_name`
/// extension as defined by RFC3546, section 3.1, often referred to as SNI. There are probably
/// other situations where an authority can be mentionned on a connection without TLS being
/// involved at all.
Authority(Cow<'a, str>),
/// The value of the type `PP2_TYPE_CRC32C` is a 32-bit number storing the `CRC32c` checksum of
/// the PROXY protocol header.
///
/// When the checksum is supported by the sender after constructing the header the sender MUST:
///
/// - initialize the checksum field to '0's.
///
/// - calculate the `CRC32c` checksum of the PROXY header as described in RFC4960, Appendix B.
///
/// - put the resultant value into the checksum field, and leave the rest of the bits unchanged.
///
/// If the checksum is provided as part of the PROXY header and the checksum functionality is
/// supported by the receiver, the receiver MUST:
///
/// - store the received `CRC32c` checksum value aside.
///
/// - replace the 32 bits of the checksum field in the received PROXY header with all '0's and
/// calculate a `CRC32c` checksum value of the whole PROXY header.
///
/// - verify that the calculated `CRC32c` checksum is the same as the received `CRC32c`
/// checksum. If it is not, the receiver MUST treat the TCP connection providing the header as
/// invalid.
///
/// The default procedure for handling an invalid TCP connection is to abort it.
Crc32c(u32),
/// The TLV of this type should be ignored when parsed. The value is zero or more bytes. Can be
/// used for data padding or alignment. Note that it can be used to align only by 3 or more
/// bytes because a TLV can not be smaller than that.
Noop(usize),
/// The value of the type `PP2_TYPE_UNIQUE_ID` is an opaque byte sequence of up to
/// 128 bytes generated by the upstream proxy that uniquely identifies the connection.
///
/// The unique ID can be used to easily correlate connections across multiple layers of
/// proxies, without needing to look up IP addresses and port numbers.
UniqueId(Cow<'a, [u8]>),
/// SSL (TLS) information
///
/// See [`SslInfo`] for more information.
Ssl(SslInfo<'a>),
/// The type `PP2_TYPE_NETNS` defines the value as the US-ASCII string representation of the
/// namespace's name.
Netns(Cow<'a, str>),
// The following can only appear as a sub-TLV of SslInfo
/// SSL/TLS version
SslVersion(Cow<'a, str>),
/// In all cases, the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3)
/// of the client certificate's Distinguished Name, is appended using the TLV format and the
/// type `PP2_SUBTYPE_SSL_CN`. E.g. "example.com".
SslCn(Cow<'a, str>),
/// The second level TLV `PP2_SUBTYPE_SSL_CIPHER` provides the US-ASCII string name of the used
/// cipher, for example "ECDHE-RSA-AES128-GCM-SHA256".
SslCipher(Cow<'a, str>),
/// The second level TLV `PP2_SUBTYPE_SSL_SIG_ALG` provides the US-ASCII string name of the
/// algorithm used to sign the certificate presented by the frontend when the incoming
/// connection was made over an SSL/TLS transport layer, for example "SHA256".
SslSigAlg(Cow<'a, str>),
/// The second level TLV `PP2_SUBTYPE_SSL_KEY_ALG` provides the US-ASCII string name of the
/// algorithm used to generate the key of the certificate presented by the frontend when the
/// incoming connection was made over an SSL/TLS transport layer, for example "RSA2048".
SslKeyAlg(Cow<'a, str>),
/// Unrecognized or custom TLV field
Custom(u8, Cow<'a, [u8]>),
}
impl<'a> Tlv<'a> {
fn decode(kind: u8, data: &'a [u8]) -> Result<Self, ProxyProtocolError> {
match kind {
0x01 => Ok(Self::Alpn(data.into())),
0x02 => Ok(Self::Authority(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x03 => Ok(Self::Crc32c(u32::from_be_bytes(
data.try_into().map_err(|_| ProxyProtocolError::Invalid)?,
))),
0x04 => Ok(Self::Noop(data.len())),
0x05 => Ok(Self::UniqueId(data.into())),
0x20 => Ok(Tlv::Ssl(SslInfo(
*data.first().ok_or(ProxyProtocolError::Invalid)?,
u32::from_be_bytes(
data.get(1..5)
.ok_or(ProxyProtocolError::Invalid)?
.try_into()
.map_err(|_| ProxyProtocolError::Invalid)?,
),
data.get(5..).ok_or(ProxyProtocolError::Invalid)?.into(),
))),
0x21 => Ok(Self::SslVersion(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x22 => Ok(Self::SslCn(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x23 => Ok(Self::SslCipher(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x24 => Ok(Self::SslSigAlg(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x25 => Ok(Self::SslKeyAlg(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x30 => Ok(Self::Netns(
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
t => Ok(Self::Custom(t, data.into())),
}
}
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn into_owned(self) -> Tlv<'static> {
match self {
Self::Alpn(v) => Tlv::Alpn(Cow::Owned(v.into_owned())),
Self::Authority(v) => Tlv::Authority(Cow::Owned(v.into_owned())),
Self::Crc32c(v) => Tlv::Crc32c(v),
Self::Noop(v) => Tlv::Noop(v),
Self::UniqueId(v) => Tlv::UniqueId(Cow::Owned(v.into_owned())),
Self::Ssl(v) => Tlv::Ssl(v.into_owned()),
Self::Netns(v) => Tlv::Netns(Cow::Owned(v.into_owned())),
Self::SslVersion(v) => Tlv::SslVersion(Cow::Owned(v.into_owned())),
Self::SslCn(v) => Tlv::SslCn(Cow::Owned(v.into_owned())),
Self::SslCipher(v) => Tlv::SslCipher(Cow::Owned(v.into_owned())),
Self::SslSigAlg(v) => Tlv::SslSigAlg(Cow::Owned(v.into_owned())),
Self::SslKeyAlg(v) => Tlv::SslKeyAlg(Cow::Owned(v.into_owned())),
Self::Custom(a, v) => Tlv::Custom(a, Cow::Owned(v.into_owned())),
}
}
}
/// SSL information from a PROXY protocol header
#[derive(PartialEq, Eq, Clone)]
pub(crate) struct SslInfo<'a>(u8, u32, Cow<'a, [u8]>);
impl SslInfo<'_> {
/// Client connected over SSL/TLS
///
/// The `PP2_CLIENT_SSL` flag indicates that the client connected over SSL/TLS. When this field
/// is present, the US-ASCII string representation of the TLS version is appended at the end of
/// the field in the TLV format using the type `PP2_SUBTYPE_SSL_VERSION`.
pub(crate) fn client_ssl(&self) -> bool {
self.0 & 0x01 != 0
}
/// Client certificate presented in the connection
///
/// `PP2_CLIENT_CERT_CONN` indicates that the client provided a certificate over the current
/// connection.
pub(crate) fn client_cert_conn(&self) -> bool {
self.0 & 0x02 != 0
}
/// Client certificate presented in the session
///
/// `PP2_CLIENT_CERT_SESS` indicates that the client provided a certificate at least once over
/// the TLS session this connection belongs to.
pub(crate) fn client_cert_sess(&self) -> bool {
self.0 & 0x04 != 0
}
/// Whether the certificate was verified
///
/// The verify field will be zero if the client presented a certificate and it was successfully
/// verified, and non-zero otherwise.
pub(crate) fn verify(&self) -> u32 {
self.1
}
/// Iterator over all TLV (type-length-value) fields
pub(crate) fn tlvs(&self) -> Tlvs<'_> {
Tlvs { buf: &self.2 }
}
// Convenience accessors for common TLVs
/// SSL version
///
/// See [`Tlv::SslVersion`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn version(&self) -> Option<&str> {
tlv_borrowed!(self, SslVersion)
}
/// SSL CN
///
/// See [`Tlv::SslCn`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn cn(&self) -> Option<&str> {
tlv_borrowed!(self, SslCn)
}
/// SSL cipher
///
/// See [`Tlv::SslCipher`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn cipher(&self) -> Option<&str> {
tlv_borrowed!(self, SslCipher)
}
/// SSL signature algorithm
///
/// See [`Tlv::SslSigAlg`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn sig_alg(&self) -> Option<&str> {
tlv_borrowed!(self, SslSigAlg)
}
/// SSL key algorithm
///
/// See [`Tlv::SslKeyAlg`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn key_alg(&self) -> Option<&str> {
tlv_borrowed!(self, SslKeyAlg)
}
/// Returns an owned version of this struct
pub(crate) fn into_owned(self) -> SslInfo<'static> {
SslInfo(self.0, self.1, Cow::Owned(self.2.into_owned()))
}
}
impl fmt::Debug for SslInfo<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("SslInfo")
.field("verify", &self.verify())
.field("client_ssl", &self.client_ssl())
.field("client_cert_conn", &self.client_cert_conn())
.field("client_cert_sess", &self.client_cert_sess())
.field("fields", &self.tlvs().collect::<Vec<_>>())
.finish()
}
}
/// A PROXY protocol header
#[derive(Default, PartialEq, Eq, Clone, Debug)]
pub(crate) struct Header<'a>(pub(super) Option<Address>, pub(super) Cow<'a, [u8]>);
impl<'a> Header<'a> {
/// Attempt to parse a PROXY protocol header from the given buffer
///
/// Returns the parsed header and the number of bytes consumed from the buffer. If the header
/// is incomplete, returns [`ProxyProtocolError::BufferTooShort`] so more data can be read from
/// the socket.
///
/// If the header is malformed or unsupported, returns [`ProxyProtocolError::Invalid`].
///
/// This function will borrow the buffer for the lifetime of the returned header. If
/// you need to keep the header around for longer than the buffer, use
/// [`Header::into_owned`].
#[instrument(skip_all)]
pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), ProxyProtocolError> {
match buf.first() {
Some(b'P') => super::v1::decode(buf),
Some(b'\r') => super::v2::decode(buf),
None => Err(ProxyProtocolError::BufferTooShort),
_ => Err(ProxyProtocolError::Invalid),
}
}
/// Proxied address information
///
/// If `None`, this indicates so-called "local" mode, where the connection is not proxied.
/// This is usually the case when the connection is initiated by the proxy itself, e.g. for
/// health checks.
pub(crate) fn proxied_address(&self) -> Option<&Address> {
self.0.as_ref()
}
/// Iterator that yields all extension TLV (type-length-value) fields present in the header
///
/// See [`Tlv`] for more information on the different types of TLV fields.
pub(crate) fn tlvs(&self) -> Tlvs<'_> {
Tlvs { buf: &self.1 }
}
// Convenience accessors for common fields
/// Raw ALPN extension data
///
/// See [`Tlv::Alpn`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn alpn(&self) -> Option<&[u8]> {
tlv_borrowed!(self, Alpn)
}
/// Authority - typically the hostname of the client (SNI)
///
/// See [`Tlv::Authority`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn authority(&self) -> Option<&str> {
tlv_borrowed!(self, Authority)
}
/// `CRC32c` checksum of the address information
///
/// See [`Tlv::Crc32c`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn crc32c(&self) -> Option<u32> {
tlv!(self, Crc32c)
}
/// Unique ID of the connection
///
/// See [`Tlv::UniqueId`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn unique_id(&self) -> Option<&[u8]> {
tlv_borrowed!(self, UniqueId)
}
/// SSL information
///
/// See [`Tlv::Ssl`] for more information.
pub(crate) fn ssl(&self) -> Option<SslInfo<'_>> {
tlv!(self, Ssl)
}
/// Network namespace
///
/// See [`Tlv::Netns`] for more information.
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn netns(&self) -> Option<&str> {
tlv_borrowed!(self, Netns)
}
/// Returns an owned version of this struct
pub(crate) fn into_owned(self) -> Header<'static> {
Header(self.0, Cow::Owned(self.1.into_owned()))
}
}
#[derive(Debug, PartialEq, Eq, Error)]
pub(crate) enum ProxyProtocolError {
#[error("The buffer is too short to contain a complete PROXY protocol header")]
BufferTooShort,
#[error("The PROXY protocol header is malformed")]
Invalid,
}
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use super::*;
const V1_UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
const V1_TCPV4: &[u8] = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
const V1_TCPV6: &[u8] = b"PROXY TCP6 2001:db8::1 ::1 12345 443\r\n";
const V2_LOCAL: &[u8] =
b"\r\n\r\n\0\r\nQUIT\n \0\0\x0f\x03\0\x04\x88\x9d\xa1\xdf \0\x05\0\0\0\0\0";
const V2_TCPV4: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 12, 127, 0, 0, 1, 192, 168, 0, 1,
48, 57, 1, 187,
];
const V2_TCPV6: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 33, 0, 36, 32, 1, 13, 184, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 48, 57, 1, 187,
];
const V2_TCPV4_TLV: &[u8] = &[
13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 104, 127, 0, 0, 1, 192, 168, 0,
1, 48, 57, 1, 187, 3, 0, 4, 211, 153, 216, 216, 5, 0, 4, 49, 50, 51, 52, 32, 0, 75, 7, 0,
0, 0, 0, 33, 0, 7, 84, 76, 83, 118, 49, 46, 51, 34, 0, 9, 108, 111, 99, 97, 108, 104, 111,
115, 116, 37, 0, 7, 82, 83, 65, 52, 48, 57, 54, 36, 0, 10, 82, 83, 65, 45, 83, 72, 65, 50,
53, 54, 35, 0, 22, 84, 76, 83, 95, 65, 69, 83, 95, 50, 53, 54, 95, 71, 67, 77, 95, 83, 72,
65, 51, 56, 52,
];
#[test]
fn parse_proxy_header_too_short() {
for case in [
V1_TCPV4,
V1_TCPV6,
V1_UNKNOWN,
V2_TCPV4,
V2_TCPV6,
V2_TCPV4_TLV,
V2_LOCAL,
]
.iter()
{
for i in 0..case.len() {
assert!(matches!(
Header::parse(&case[..i]),
Err(ProxyProtocolError::BufferTooShort)
));
}
assert!(Header::parse(case).is_ok());
}
}
#[test]
fn test_parse_proxy_header_v1_unterminated() {
let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR";
assert!(matches!(
Header::parse(line),
Err(ProxyProtocolError::Invalid)
));
}
#[test]
fn test_parse_proxy_header_v1() {
let (res, consumed) = Header::parse(V1_TCPV4).expect("failed to parse");
assert_eq!(consumed, V1_TCPV4.len());
assert_eq!(
res.0,
Some(Address {
protocol: Protocol::Stream,
source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
})
);
assert_eq!(res.1, vec![0; 0]);
let (res, consumed) = Header::parse(V1_TCPV6).expect("failed to parse");
assert_eq!(consumed, V1_TCPV6.len());
assert_eq!(
res.0,
Some(Address {
protocol: Protocol::Stream,
source: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
12345
),
destination: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
443
),
})
);
assert_eq!(res.1, vec![0; 0]);
}
#[test]
fn test_parse_proxy_header_v2() {
let (res, consumed) = Header::parse(V2_LOCAL).expect("failed to parse");
assert_eq!(consumed, V2_LOCAL.len());
assert_eq!(res.0, None);
let (res, consumed) = Header::parse(V2_TCPV4).expect("failed to parse");
assert_eq!(consumed, V2_TCPV4.len());
assert_eq!(
res.0,
Some(Address {
protocol: Protocol::Stream,
source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
})
);
assert_eq!(res.1, vec![0; 0]);
let (res, consumed) = Header::parse(V2_TCPV6).expect("failed to parse");
assert_eq!(consumed, V2_TCPV6.len());
assert_eq!(
res.0,
Some(Address {
protocol: Protocol::Stream,
source: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)),
12345
),
destination: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
443
),
})
);
assert_eq!(res.1, vec![0; 0]);
}
#[test]
fn test_parse_proxy_header_v2_with_tlvs() {
let (res, _) = Header::parse(V2_TCPV4_TLV).expect("failed to parse");
let mut fields = res.tlvs();
assert_eq!(fields.next(), Some(Ok(Tlv::Crc32c(0xd399_d8d8))));
assert_eq!(fields.next(), Some(Ok(Tlv::UniqueId(b"1234"[..].into()))));
let ssl = fields
.next()
.expect("next tlv missing")
.expect("tlv parsing failed");
let ssl = match ssl {
Tlv::Ssl(ssl) => ssl,
_ => panic!("expected SSL TLV"),
};
assert_eq!(ssl.verify(), 0);
assert!(ssl.client_ssl());
assert!(ssl.client_cert_conn());
assert!(ssl.client_cert_sess());
let mut f = ssl.tlvs();
assert_eq!(f.next(), Some(Ok(Tlv::SslVersion("TLSv1.3".into()))));
assert_eq!(f.next(), Some(Ok(Tlv::SslCn("localhost".into()))));
assert_eq!(f.next(), Some(Ok(Tlv::SslKeyAlg("RSA4096".into()))));
assert_eq!(f.next(), Some(Ok(Tlv::SslSigAlg("RSA-SHA256".into()))));
assert_eq!(
f.next(),
Some(Ok(Tlv::SslCipher("TLS_AES_256_GCM_SHA384".into())))
);
assert!(f.next().is_none());
assert!(fields.next().is_none());
}
}

View File

@@ -0,0 +1,243 @@
use std::{
cmp::min,
io,
io::IoSlice,
ops::Deref,
pin::Pin,
task::{Context, Poll},
};
use eyre::{Result, eyre};
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf};
use tracing::instrument;
use crate::tokio::proxy_protocol::header::{Header, ProxyProtocolError};
pub(crate) mod header;
mod utils;
mod v1;
mod v2;
// Length of the read buffer
const READ_BUFFER_LEN: usize = 536;
pin_project! {
pub struct ProxyProtocolStream<S> {
#[pin]
stream: S,
remaining: Vec<u8>,
header: Option<Header<'static>>,
}
}
impl<S> ProxyProtocolStream<S> {
pub(crate) fn header(&self) -> Option<&Header<'static>> {
self.header.as_ref()
}
fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
self.project().stream
}
#[expect(unused, reason = "Left here if we extract this to a public library")]
pub(crate) fn try_into_stream(self) -> Result<S> {
if self.remaining.is_empty() {
Ok(self.stream)
} else {
Err(eyre!(
"Cannot return inner stream because buffer is not empty"
))
}
}
}
impl<S> AsRef<S> for ProxyProtocolStream<S> {
fn as_ref(&self) -> &S {
&self.stream
}
}
impl<S> Deref for ProxyProtocolStream<S> {
type Target = S;
fn deref(&self) -> &Self::Target {
&self.stream
}
}
impl<S> ProxyProtocolStream<S>
where S: AsyncRead + Unpin
{
#[instrument(skip_all)]
pub(crate) async fn new(mut stream: S) -> Result<Self, io::Error> {
let mut remaining = Vec::with_capacity(READ_BUFFER_LEN);
loop {
let bytes_read = stream.read_buf(&mut remaining).await?;
if bytes_read == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"end of stream",
));
}
match Header::parse(&remaining) {
Ok((header, consumed)) => {
let header = header.into_owned();
remaining.drain(..consumed);
return Ok(Self {
stream,
remaining,
header: Some(header),
});
}
Err(ProxyProtocolError::BufferTooShort) => {}
// Something went wrong parsing the PROXY protocol. We assume that we weren't meant
// to parse it, and that this is just a regular stream without the PROXY protocol.
Err(_) => {
return Ok(Self {
stream,
remaining,
header: None,
});
}
}
}
}
}
impl<S> AsyncRead for ProxyProtocolStream<S>
where S: AsyncRead
{
#[instrument(skip_all)]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.project();
if !this.remaining.is_empty() {
let to_copy = min(this.remaining.len(), buf.remaining());
buf.put_slice(&this.remaining[..to_copy]);
this.remaining.drain(..to_copy);
return Poll::Ready(Ok(()));
}
this.stream.poll_read(cx, buf)
}
}
impl<S> AsyncBufRead for ProxyProtocolStream<S>
where S: AsyncBufRead
{
#[instrument(skip_all)]
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let this = self.project();
if !this.remaining.is_empty() {
return Poll::Ready(Ok(&this.remaining[..]));
}
this.stream.poll_fill_buf(cx)
}
#[instrument(skip_all)]
fn consume(self: Pin<&mut Self>, amt: usize) {
let this = self.project();
if this.remaining.is_empty() {
this.stream.consume(amt);
} else {
let len = this.remaining.len();
if amt <= len {
this.remaining.drain(..amt);
} else {
this.remaining.drain(..len);
this.stream.consume(amt - len);
}
}
}
}
impl<S> AsyncWrite for ProxyProtocolStream<S>
where S: AsyncWrite
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.get_pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.get_pin_mut().poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
self.get_pin_mut().poll_write_vectored(cx, bufs)
}
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}
#[cfg(test)]
mod tests {
use std::{
io::Cursor,
net::{IpAddr, Ipv4Addr, SocketAddr},
};
use super::{
header::{Address, Protocol},
*,
};
#[tokio::test]
async fn test_parse() {
let mut buf = [0; 1024];
let header = b"PROXY TCP4 127.0.0.1 192.168.0.1 12345 443\r\n";
buf[..header.len()].copy_from_slice(header);
buf[header.len()..].fill(255);
let mut stream = Cursor::new(&buf);
let mut proxied = ProxyProtocolStream::new(&mut stream)
.await
.expect("failed to create stream");
assert_eq!(
proxied.header(),
Some(Header(
Some(Address {
protocol: Protocol::Stream,
source: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 12345),
destination: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)), 443),
}),
vec![0; 0].into(),
))
.as_ref()
);
let mut buf = Vec::new();
AsyncReadExt::read_to_end(&mut proxied, &mut buf)
.await
.expect("failed to read from stream");
assert_eq!(buf.len(), 1024 - header.len());
assert!(buf.into_iter().all(|b| b == 255));
}
}

View File

@@ -0,0 +1,39 @@
use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr},
str::FromStr,
};
pub(super) fn read_until(buf: &[u8], delim: u8) -> Option<&[u8]> {
for i in 0..buf.len() {
if buf[i] == delim {
return Some(&buf[..i]);
}
}
None
}
pub(super) trait AddressFamily: FromStr + Into<IpAddr> {
const BYTES: usize;
fn from_slice(slice: &[u8]) -> Self;
}
impl AddressFamily for Ipv4Addr {
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
const BYTES: usize = (Self::BITS / 8) as usize;
fn from_slice(slice: &[u8]) -> Self {
let arr: [u8; Self::BYTES] = slice.try_into().expect("slice must be 4 bytes");
arr.into()
}
}
impl AddressFamily for Ipv6Addr {
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
const BYTES: usize = (Self::BITS / 8) as usize;
fn from_slice(slice: &[u8]) -> Self {
let arr: [u8; Self::BYTES] = slice.try_into().expect("slice must be 16 bytes");
arr.into()
}
}

View File

@@ -0,0 +1,110 @@
use std::{
borrow::Cow,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
str::{FromStr as _, from_utf8},
};
use super::{
header::{Address, Header, Protocol, ProxyProtocolError},
utils::{AddressFamily, read_until},
};
const MAX_LENGTH: usize = 107;
const GREETING: &[u8] = b"PROXY";
const UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
// All other valid PROXY headers are longer than this
const MIN_LENGTH: usize = UNKNOWN.len();
fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, ProxyProtocolError> {
let Some(address) = read_until(&buf[*pos..], b' ') else {
return Err(ProxyProtocolError::BufferTooShort);
};
let addr = from_utf8(address)
.map_err(|_| ProxyProtocolError::Invalid)
.and_then(|s| A::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
*pos += address.len() + 1;
Ok(addr)
}
fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, ProxyProtocolError> {
let Some(port) = read_until(&buf[*pos..], delim) else {
return Err(ProxyProtocolError::BufferTooShort);
};
let p = from_utf8(port)
.map_err(|_| ProxyProtocolError::Invalid)
.and_then(|s| u16::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
*pos += port.len() + 1;
Ok(p)
}
fn parse_addrs<A: AddressFamily>(
buf: &[u8],
pos: &mut usize,
) -> Result<Address, ProxyProtocolError> {
let src_addr: A = parse_addr(buf, pos)?;
let dst_addr: A = parse_addr(buf, pos)?;
let src_port = parse_port(buf, pos, b' ')?;
let dst_port = parse_port(buf, pos, b'\r')?;
Ok(Address {
protocol: Protocol::Stream, // v1 only supports TCP
source: SocketAddr::new(src_addr.into(), src_port),
destination: SocketAddr::new(dst_addr.into(), dst_port),
})
}
fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
let mut pos = 0;
if buf.len() < MIN_LENGTH {
return Err(ProxyProtocolError::BufferTooShort);
}
if !buf.starts_with(GREETING) {
return Err(ProxyProtocolError::Invalid);
}
pos += GREETING.len() + 1;
let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
let Some(rest) = read_until(&buf[pos..], b'\r') else {
return Err(ProxyProtocolError::BufferTooShort);
};
pos += rest.len() + 1;
None
} else {
let proto = &buf[pos..pos + 5];
pos += 5;
match proto {
b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
_ => return Err(ProxyProtocolError::Invalid),
}
};
match buf.get(pos) {
Some(b'\n') => pos += 1,
None => return Err(ProxyProtocolError::BufferTooShort),
_ => return Err(ProxyProtocolError::Invalid),
}
Ok((Header(addrs, Cow::default()), pos))
}
/// Decode a version 1 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
// Guard against a malicious client sending a very long header, since it is a
// delimited protocol.
match decode_inner(buf) {
Err(ProxyProtocolError::BufferTooShort) if buf.len() >= MAX_LENGTH => {
Err(ProxyProtocolError::Invalid)
}
other => other,
}
}

View File

@@ -0,0 +1,122 @@
use std::{
borrow::Cow,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};
use super::{
header::{Address, Header, Protocol, ProxyProtocolError},
utils::AddressFamily,
};
const GREETING: &[u8] = b"\r\n\r\n\x00\r\nQUIT\n";
const MIN_LENGTH: usize = GREETING.len() + 4;
const AF_UNIX_ADDRS_LEN: usize = 216;
fn parse_addrs<T: AddressFamily>(
buf: &[u8],
pos: &mut usize,
rest: &mut usize,
protocol: Protocol,
) -> Result<Address, ProxyProtocolError> {
if buf.len() < *pos + T::BYTES * 2 + 4 {
return Err(ProxyProtocolError::BufferTooShort);
}
if *rest < T::BYTES * 2 + 4 {
return Err(ProxyProtocolError::Invalid);
}
let addr = Address {
protocol,
source: SocketAddr::new(
T::from_slice(&buf[*pos..*pos + T::BYTES]).into(),
u16::from_be_bytes([buf[*pos + T::BYTES * 2], buf[*pos + T::BYTES * 2 + 1]]),
),
destination: SocketAddr::new(
T::from_slice(&buf[*pos + T::BYTES..*pos + T::BYTES * 2]).into(),
u16::from_be_bytes([buf[*pos + T::BYTES * 2 + 2], buf[*pos + T::BYTES * 2 + 3]]),
),
};
*rest -= T::BYTES * 2 + 4;
*pos += T::BYTES * 2 + 4;
Ok(addr)
}
/// Decode a version 2 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
let mut pos = 0;
if buf.len() < MIN_LENGTH {
return Err(ProxyProtocolError::BufferTooShort);
}
if !buf.starts_with(GREETING) {
return Err(ProxyProtocolError::Invalid);
}
pos += GREETING.len();
let is_local = match buf[pos] {
0x20 => true,
0x21 => false,
_ => return Err(ProxyProtocolError::Invalid),
};
let protocol = buf[pos + 1];
let mut rest: usize = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]).into();
pos += 4;
if buf.len() < pos + rest {
return Err(ProxyProtocolError::BufferTooShort);
}
let addr_info = match protocol {
0x00 => None,
0x11 => Some(parse_addrs::<Ipv4Addr>(
buf,
&mut pos,
&mut rest,
Protocol::Stream,
)?),
0x12 => Some(parse_addrs::<Ipv4Addr>(
buf,
&mut pos,
&mut rest,
Protocol::Datagram,
)?),
0x21 => Some(parse_addrs::<Ipv6Addr>(
buf,
&mut pos,
&mut rest,
Protocol::Stream,
)?),
0x22 => Some(parse_addrs::<Ipv6Addr>(
buf,
&mut pos,
&mut rest,
Protocol::Datagram,
)?),
0x31 | 0x32 => {
// AF_UNIX - we don't parse it, but don't reject it either in case we need the TLVs
if rest < AF_UNIX_ADDRS_LEN {
return Err(ProxyProtocolError::Invalid);
}
rest -= AF_UNIX_ADDRS_LEN;
pos += AF_UNIX_ADDRS_LEN;
None
}
_ => return Err(ProxyProtocolError::Invalid),
};
let tlv_data = Cow::Borrowed(&buf[pos..pos + rest]);
pos += rest;
let header = if is_local {
Header(None, tlv_data)
} else {
Header(addr_info, tlv_data)
};
Ok((header, pos))
}

123
src/tracing.rs Normal file
View File

@@ -0,0 +1,123 @@
use eyre::Result;
use tracing_error::ErrorLayer;
use tracing_subscriber::{filter::EnvFilter, fmt, prelude::*};
use crate::config;
pub(super) fn install() -> Result<()> {
let config = config::get();
let mut filter_layer = EnvFilter::builder()
.with_default_directive(config.log_level.parse()?)
.parse(&config.log_level)?;
for (k, v) in &config.log.rust_log {
filter_layer = filter_layer.add_directive(format!("{k}={v}").parse()?);
}
if config.debug {
let console_layer = console_subscriber::ConsoleLayer::builder()
.server_addr(config.listen.debug)
.spawn();
tracing_subscriber::registry()
.with(ErrorLayer::default())
.with(console_layer)
.with(
fmt::layer()
.compact()
.event_format(
fmt::format()
.with_thread_ids(true)
.with_thread_names(true)
.with_source_location(true)
.compact(),
)
.with_writer(std::io::stderr)
.with_filter(filter_layer),
)
.with(::sentry::integrations::tracing::layer())
.init();
} else {
tracing_subscriber::registry()
.with(ErrorLayer::default())
.with(json::layer().with_filter(filter_layer))
.with(::sentry::integrations::tracing::layer())
.init();
}
Ok(())
}
pub(super) fn install_crude() -> tracing::dispatcher::DefaultGuard {
let filter_layer = EnvFilter::builder()
.parse("trace,console_subscriber=info,runtime=info,tokio=info,tungstenite=info")
.expect("infallible");
let subscriber = tracing_subscriber::registry()
.with(ErrorLayer::default())
.with(filter_layer)
.with(json::layer());
tracing::dispatcher::set_default(&subscriber.into())
}
mod json {
use std::collections::HashMap;
use tracing::Subscriber;
use tracing_subscriber::{layer::Layer, registry::LookupSpan};
pub(super) fn layer<S>() -> impl Layer<S>
where S: Subscriber + for<'lookup> LookupSpan<'lookup> {
let mut json_layer = json_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true)
.flatten_event(true)
.flatten_current_span_on_top_level(true);
let inner_layer = json_layer.inner_layer_mut();
inner_layer.with_thread_ids("thread_id");
inner_layer.with_thread_names("thread_name");
inner_layer.add_dynamic_field("pid", |_, _| {
Some(serde_json::Value::Number(serde_json::Number::from(
std::process::id(),
)))
});
inner_layer.with_flattened_event_with_renames(
move |name, map| match map.get(name) {
Some(name) => name.as_str(),
None => name,
},
HashMap::from([("message".to_owned(), "event".to_owned())]),
);
json_layer
}
}
pub(crate) mod sentry {
use std::str::FromStr as _;
use tracing::trace;
use crate::{VERSION, authentik_user_agent, config};
pub(crate) fn install() -> sentry::ClientInitGuard {
trace!("setting up sentry");
let config = config::get();
sentry::init(sentry::ClientOptions {
dsn: config.error_reporting.sentry_dsn.clone().map(|dsn| {
sentry::types::Dsn::from_str(&dsn).expect("Failed to create sentry DSN")
}),
release: Some(format!("authentik@{VERSION}").into()),
environment: Some(config.error_reporting.environment.clone().into()),
attach_stacktrace: true,
send_default_pii: config.error_reporting.send_pii,
sample_rate: config.error_reporting.sample_rate,
traces_sample_rate: if config.debug {
1.0
} else {
config.error_reporting.sample_rate
},
user_agent: authentik_user_agent().into(),
..sentry::ClientOptions::default()
})
}
}

366
src/worker.rs Normal file
View File

@@ -0,0 +1,366 @@
use std::{
env::temp_dir,
os::unix,
path::PathBuf,
process::Stdio,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use argh::FromArgs;
use axum::body::Body;
use eyre::{Result, eyre};
use hyper_unix_socket::UnixSocketConnector;
use hyper_util::{client::legacy::Client, rt::TokioExecutor};
use nix::{
sys::signal::{Signal, kill},
unistd::Pid,
};
use tokio::{
net::UnixStream,
process::{Child, Command},
signal::unix::SignalKind,
sync::{Mutex, broadcast::error::RecvError},
};
use tracing::{info, trace, warn};
use crate::{
arbiter::{Arbiter, Tasks},
axum::server,
config,
mode::Mode,
};
#[derive(Debug, Default, FromArgs, PartialEq)]
/// Run the authentik worker.
#[argh(subcommand, name = "worker")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}
const INITIAL_WORKER_ID: usize = 1000;
static INITIAL_WORKER_READY: AtomicBool = AtomicBool::new(false);
struct Worker(Child);
impl Worker {
fn new(worker_id: usize, socket_path: Option<&str>) -> Result<Self> {
info!(worker_id, "Starting worker");
let mut cmd = Command::new("python");
cmd.args(["-m", "lifecycle.worker_process", &worker_id.to_string()]);
if let Some(socket_path) = socket_path {
cmd.arg(socket_path);
}
Ok(Self(
cmd.kill_on_drop(true)
.stdout(Stdio::inherit())
.stderr(Stdio::inherit())
.spawn()?,
))
}
async fn shutdown(&mut self, signal: Signal) -> Result<()> {
trace!(
signal = signal.as_str(),
"sending shutdown signal to worker"
);
if let Some(id) = self.0.id() {
kill(Pid::from_raw(id.cast_signed()), signal)?;
}
self.0.wait().await?;
Ok(())
}
async fn graceful_shutdown(&mut self) -> Result<()> {
info!("gracefully shutting down worker");
self.shutdown(Signal::SIGTERM).await
}
async fn fast_shutdown(&mut self) -> Result<()> {
info!("immediately shutting down worker");
self.shutdown(Signal::SIGINT).await
}
fn is_alive(&mut self) -> bool {
let try_wait = self.0.try_wait();
match try_wait {
Ok(Some(code)) => {
warn!("worker has exited with status {code}");
false
}
Ok(None) => true,
Err(err) => {
warn!("failed to check the status of worker process, ignoring: {err}");
true
}
}
}
}
pub(crate) struct Workers {
workers: Mutex<Vec<Worker>>,
socket_path: PathBuf,
pub(crate) client: Client<UnixSocketConnector<PathBuf>, Body>,
}
impl Workers {
fn new(socket_path: PathBuf) -> Result<Self> {
let mut workers = Vec::with_capacity(config::get().worker.processes.get());
workers.push(Worker::new(
INITIAL_WORKER_ID,
Some(&format!("{}", &socket_path.display())),
)?);
let client = Client::builder(TokioExecutor::new())
.pool_idle_timeout(Duration::from_secs(60))
.set_host(false)
.build(UnixSocketConnector::new(socket_path.clone()));
Ok(Self {
workers: Mutex::new(workers),
socket_path,
client,
})
}
async fn start_other_workers(&self) -> Result<()> {
for i in 1..config::get().worker.processes.get() {
self.workers
.lock()
.await
.push(Worker::new(INITIAL_WORKER_ID + i, None)?);
}
Ok(())
}
async fn graceful_shutdown(&self) -> Result<()> {
let mut results = Vec::with_capacity(self.workers.lock().await.capacity());
for worker in self.workers.lock().await.iter_mut() {
results.push(worker.graceful_shutdown().await);
}
results.into_iter().find(Result::is_err).unwrap_or(Ok(()))
}
async fn fast_shutdown(&self) -> Result<()> {
let mut results = Vec::with_capacity(self.workers.lock().await.capacity());
for worker in self.workers.lock().await.iter_mut() {
results.push(worker.fast_shutdown().await);
}
results.into_iter().find(Result::is_err).unwrap_or(Ok(()))
}
pub(crate) async fn are_alive(&self) -> bool {
for worker in self.workers.lock().await.iter_mut() {
if !worker.is_alive() {
return false;
}
}
true
}
async fn is_socket_ready(&self) -> bool {
let result = UnixStream::connect(&self.socket_path).await;
trace!("checking if worker socket is ready: {result:?}");
result.is_ok()
}
}
async fn watch_workers(arbiter: Arbiter, workers: Arc<Workers>) -> Result<()> {
info!("starting worker watcher");
let mut signals_rx = arbiter.signals_subscribe();
loop {
tokio::select! {
signal = signals_rx.recv() => {
match signal {
Ok(signal) => {
if signal == SignalKind::user_defined2() {
info!("worker notified us ready, marked ready for operation");
INITIAL_WORKER_READY.store(true, Ordering::Relaxed);
workers.start_other_workers().await?;
}
},
Err(RecvError::Lagged(_)) => {},
Err(RecvError::Closed) => {
warn!("error receiving signals");
return Err(RecvError::Closed.into());
}
}
},
() = tokio::time::sleep(Duration::from_secs(1)), if !INITIAL_WORKER_READY.load(Ordering::Relaxed) => {
// On some platforms the SIGUSR1 can be missed.
// Fall back to probing the worker unix socket and mark ready once it accepts connections.
if workers.is_socket_ready().await {
info!("worker socket is accepting connections, marked ready for operation");
INITIAL_WORKER_READY.store(true, Ordering::Relaxed);
workers.start_other_workers().await?;
}
},
() = tokio::time::sleep(Duration::from_secs(5)) => {
if !workers.are_alive().await {
return Err(eyre!("gunicorn has exited unexpectedly"));
}
},
() = arbiter.fast_shutdown() => {
workers.fast_shutdown().await?;
return Ok(());
},
() = arbiter.graceful_shutdown() => {
workers.graceful_shutdown().await?;
return Ok(());
},
}
}
}
mod healthcheck {
use std::sync::Arc;
use axum::{
Router,
body::Body,
extract::{Request, State},
http::{StatusCode, header::HOST},
response::IntoResponse,
routing::any,
};
use crate::{axum::router::wrap_router, db, worker::Workers};
async fn health_ready(State(workers): State<Arc<Workers>>) -> impl IntoResponse {
if !workers.are_alive().await || sqlx::query("SELECT 1").execute(db::get()).await.is_err() {
StatusCode::SERVICE_UNAVAILABLE
} else {
let req = Request::builder()
.method("GET")
.uri("http://localhost:8000/-/health/ready/")
.header(HOST, "localhost")
.body(Body::from(""));
if let Ok(req) = req
&& let Ok(res) = workers.client.request(req).await
{
res.status()
} else {
StatusCode::SERVICE_UNAVAILABLE
}
}
}
async fn health_live(State(workers): State<Arc<Workers>>) -> impl IntoResponse {
let req = Request::builder()
.method("GET")
.uri("http://localhost:8000/-/health/live/")
.header(HOST, "localhost")
.body(Body::from(""));
if let Ok(req) = req
&& let Ok(res) = workers.client.request(req).await
{
res.status()
} else {
StatusCode::SERVICE_UNAVAILABLE
}
}
async fn fallback() -> impl IntoResponse {
StatusCode::OK
}
pub(super) fn build_router(workers: Arc<Workers>) -> Router {
wrap_router(
Router::new()
.route("/-/heath/ready/", any(health_ready))
.route("/-/heath/live/", any(health_live))
.fallback(fallback)
.with_state(workers),
true,
)
}
}
mod worker_status {
use std::time::Duration;
use eyre::Result;
use nix::unistd::gethostname;
use uuid::Uuid;
use crate::{arbiter::Arbiter, authentik_full_version, db};
async fn keep(arbiter: Arbiter, id: Uuid, hostname: &str, version: &str) -> Result<()> {
loop {
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(30)) => {
sqlx::query("
INSERT INTO authentik_tasks_workerstatus (id, hostname, version, last_seen)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (id) DO UPDATE SET last_seen = NOW()
")
.bind(id)
.bind(hostname)
.bind(version)
.execute(db::get())
.await?;
},
() = arbiter.shutdown() => return Ok(()),
}
}
}
pub(super) async fn run(arbiter: Arbiter) -> Result<()> {
let id = Uuid::new_v4();
let raw_hostname = gethostname()?;
let hostname = raw_hostname.to_string_lossy();
let version = authentik_full_version();
loop {
tokio::select! {
_ = keep(arbiter.clone(), id, hostname.as_ref(), &version) => {
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(10)) => {},
() = arbiter.shutdown() => return Ok(()),
}
},
() = arbiter.shutdown() => return Ok(()),
}
}
}
}
pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Workers>> {
let arbiter = tasks.arbiter();
let workers = Arc::new(Workers::new(temp_dir().join("authentik-worker.sock"))?);
tasks
.build_task()
.name(&format!("{}::watch_workers", module_path!()))
.spawn(watch_workers(arbiter.clone(), Arc::clone(&workers)))?;
tasks
.build_task()
.name(&format!("{}::worker_status::run", module_path!()))
.spawn(worker_status::run(arbiter))?;
if Mode::get() == Mode::Worker {
let router = healthcheck::build_router(Arc::clone(&workers));
for addr in config::get().listen.http.iter().copied() {
server::start_plain(tasks, "worker", router.clone(), addr)?;
}
server::start_unix(
tasks,
"worker",
router,
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
)?;
}
Ok(workers)
}

View File

@@ -6,17 +6,58 @@ title: Reverse proxy
Since authentik uses WebSockets to communicate with Outposts, it does not support HTTP/1.0 reverse proxies. The HTTP/1.0 specification does not officially support WebSockets or protocol upgrades, though some clients may allow it.
:::
If you want to access authentik behind a reverse proxy, there are a few headers that must be passed upstream:
- `X-Forwarded-Proto`: Tells authentik and Proxy Providers if they are being served over an HTTPS connection.
- `X-Forwarded-For`: Without this, authentik will not know the IP addresses of clients.
- `Host`: Required for various security checks, WebSocket handshake, and Outpost and Proxy Provider communication.
- `Connection: Upgrade` and `Upgrade: WebSocket`: Required to upgrade protocols for requests to the WebSocket endpoints under HTTP/1.1.
If you want to access authentik behind a reverse proxy, there are a few headers that must be passed upstream for authentik to be able to correctly identify a connection.
It is also recommended to use a [modern TLS configuration](https://ssl-config.mozilla.org/) and disable SSL/TLS protocols older than TLS 1.3.
If your reverse proxy isn't accessing authentik from a private IP address, [trusted proxy CIDRs configuration](./configuration/configuration.mdx#listen-settings) needs to be set on the authentik server to allow client IP address detection.
### Scheme
authentik and Proxy Providers need to know if they are being served over an HTTPS connection.
The connection scheme (HTTP/HTTPS) is grabbed as follows. If the incoming connection is from a trusted proxy, the following is considered:
- `X-Forwarded-Proto` header,
- `X-Forwarded-Scheme` header,
- `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239). If multiple `proto=` stanzas are present, only the first one is retained.
- whether the connection was made via TLS to the proxy via the PROXY protocol if used.
If the connection is not trusted, or the above is missing, authentik will look at whether the connection was made over plaintext or TLS.
### Host
Required for various security checks, WebSocket handshake, and Outpost and Proxy Provider communication.
The Host is grabbed as follows. If the incoming connection is from a trusted proxy, the following is considered:
- `X-Forwarded-Host` header,
- `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239). If multiple `host=` stanzas are present, only the first one is retained.
If the connection is not trusted, or the above is missing, authentik will consider the following:
- `Host` header,
- host part of the request URL.
### Client IP
authentik needs to know the IP addresses of clients for various security features and for audit purposes.
The client IP is grabbed as follows. If the incoming connection is from a trusted proxy, the following is considered:
- the rightmost IP in the `X-Forwarded-For` header,
- `X-Real-IP` header,
- the rightmost IP in the `Forwarded` header, as defined in [RFC 7239](https://datatracker.ietf.org/doc/html/rfc7239),
- the IP passed via PROXY Protocol if used.
If the connection is not trusted, the client IP will be extracted from the TCP metadata.
### WebSockets
The `Connection: Upgrade` and `Upgrade: WebSocket` headers are required to upgrade protocols for requests to the WebSocket endpoints under HTTP/1.1.
### Example configuration
The following nginx configuration can be used as a starting point for your own configuration.
```