diff --git a/Cargo.lock b/Cargo.lock index 53dffa2371..238ae9c880 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,15 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "arc-swap" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" +dependencies = [ + "rustversion", +] + [[package]] name = "async-trait" version = "0.1.89" @@ -105,6 +114,14 @@ dependencies = [ [[package]] name = "authentik-common" version = "2026.5.0-rc1" +dependencies = [ + "axum-server", + "eyre", + "nix", + "tokio", + "tokio-util", + "tracing", +] [[package]] name = "autocfg" @@ -150,6 +167,28 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "axum-server" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1df331683d982a0b9492b38127151e6453639cd34926eb9c07d4cd8c6d22bfc" +dependencies = [ + "arc-swap", + "bytes", + "either", + "fs-err", + "http", + "http-body", + "hyper", + "hyper-util", + "pin-project-lite", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + [[package]] name = "base64" version = "0.22.1" @@ -465,6 +504,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs-err" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0" +dependencies = [ + "autocfg", + "tokio", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -662,6 +711,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.8.1" @@ -676,6 +731,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -1040,6 +1096,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nix" +version = "0.31.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +dependencies = [ + "bitflags", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -1912,9 +1980,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tracing-core" version = "0.1.36" diff --git a/Cargo.toml b/Cargo.toml index 1b6eed7ccb..e15625cacd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,11 +18,13 @@ license-file = "LICENSE" publish = false [workspace.dependencies] +axum-server = { version = "= 0.8.0", features = ["tls-rustls-no-provider"] } aws-lc-rs = { version = "= 1.16.2", features = ["fips"] } clap = { version = "= 4.6.0", features = ["derive", "env"] } colored = "= 3.1.1" dotenvy = "= 0.15.7" eyre = "= 0.6.12" +nix = { version = "= 0.31.2", features = ["signal"] } regex = "= 1.12.3" reqwest = { version = "= 0.13.2", features = [ "form", @@ -48,6 +50,7 @@ serde_with = { version = "= 3.18.0", default-features = false, features = [ ] } tokio = { version = "= 1.50.0", features = ["full", "tracing"] } tokio-util = { version = "= 0.7.18", features = ["full"] } +tracing = "= 
0.1.44" url = "= 2.5.8" uuid = { version = "= 1.23.0", features = ["serde", "v4"] } diff --git a/packages/ak-common/Cargo.toml b/packages/ak-common/Cargo.toml index a7b1ed133b..399a9f007f 100644 --- a/packages/ak-common/Cargo.toml +++ b/packages/ak-common/Cargo.toml @@ -10,8 +10,14 @@ license-file.workspace = true publish.workspace = true [dependencies] +axum-server.workspace = true +eyre.workspace = true +tokio.workspace = true +tokio-util.workspace = true +tracing.workspace = true [dev-dependencies] +nix.workspace = true [lints] workspace = true diff --git a/packages/ak-common/src/arbiter.rs b/packages/ak-common/src/arbiter.rs new file mode 100644 index 0000000000..03648829e3 --- /dev/null +++ b/packages/ak-common/src/arbiter.rs @@ -0,0 +1,480 @@ +//! Utilities to manage long running tasks, such as servers and watchers, and events propagated +//! between those tasks. +//! +//! Also manages signals sent to the main process. + +use std::{net, os::unix, sync::Arc, time::Duration}; + +use axum_server::Handle; +use eyre::{Report, Result}; +use tokio::{ + signal::unix::{Signal, SignalKind, signal}, + sync::{Mutex, broadcast}, + task::{JoinSet, join_set::Builder}, +}; +use tokio_util::sync::{CancellationToken, WaitForCancellationFuture}; +use tracing::info; + +/// All the signal streams we watch for. We don't create those directly in [`watch_signals`] +/// because that would prevent us from handling errors early. +struct SignalStreams { + hup: Signal, + int: Signal, + quit: Signal, + usr1: Signal, + usr2: Signal, + term: Signal, +} + +impl SignalStreams { + fn new() -> Result { + Ok(Self { + hup: signal(SignalKind::hangup())?, + int: signal(SignalKind::interrupt())?, + quit: signal(SignalKind::quit())?, + usr1: signal(SignalKind::user_defined1())?, + usr2: signal(SignalKind::user_defined2())?, + term: signal(SignalKind::terminate())?, + }) + } +} + +/// Watch for incoming signals and either shutdown the application or dispatch them to receivers. 
+async fn watch_signals(streams: SignalStreams, arbiter: Arbiter) -> Result<()> { + info!("starting signals watcher"); + let SignalStreams { + mut hup, + mut int, + mut quit, + mut usr1, + mut usr2, + mut term, + } = streams; + loop { + tokio::select! { + _ = hup.recv() => { + info!("signal HUP received"); + arbiter.do_fast_shutdown().await; + }, + _ = int.recv() => { + info!("signal INT received"); + arbiter.do_fast_shutdown().await; + }, + _ = quit.recv() => { + info!("signal QUIT received"); + arbiter.do_fast_shutdown().await; + }, + _ = term.recv() => { + info!("signal TERM received"); + arbiter.do_graceful_shutdown().await; + }, + _ = usr1.recv() => { + info!("signal USR1 received"); + let _ = arbiter.send_event(SignalKind::user_defined1().into()); + }, + _ = usr2.recv() => { + info!("signal USR2 received"); + let _ = arbiter.send_event(SignalKind::user_defined2().into()); + }, + () = arbiter.shutdown() => { + info!("stopping signals watcher"); + return Ok(()); + } + }; + } +} + +/// Manager for long running tasks, such as servers and watchers. +pub struct Tasks { + tasks: JoinSet>, + arbiter: Arbiter, +} + +impl Tasks { + /// Create a new [`Tasks`] manager. + /// + /// # Errors + /// + /// Errors if the creation of signals watcher fails. + pub fn new() -> Result { + let mut tasks = JoinSet::new(); + let arbiter = Arbiter::new(&mut tasks)?; + + Ok(Self { tasks, arbiter }) + } + + /// Build a new task. See [`tokio::task::JoinSet::build_task`] for details. + pub fn build_task(&mut self) -> Builder<'_, Result<()>> { + self.tasks.build_task() + } + + /// Get an [`Arbiter`]. + pub fn arbiter(&self) -> Arbiter { + self.arbiter.clone() + } + + /// Run the tasks until completion. If one of them fails, terminate the program immediately.
+ pub async fn run(self) -> Vec { + let Self { mut tasks, arbiter } = self; + + let mut errors = Vec::new(); + + if let Some(result) = tasks.join_next().await { + arbiter.do_graceful_shutdown().await; + + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => { + arbiter.do_fast_shutdown().await; + errors.push(err); + } + Err(err) => { + arbiter.do_fast_shutdown().await; + errors.push(Report::new(err)); + } + } + + while let Some(result) = tasks.join_next().await { + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => errors.push(err), + Err(err) => errors.push(Report::new(err)), + } + } + } + + errors + } +} + +/// Manage shutdown state and several communication channels. +#[derive(Clone)] +pub struct Arbiter { + /// Token to shutdown the application immediately. + fast_shutdown: CancellationToken, + /// Token to shutdown the application gracefully. + graceful_shutdown: CancellationToken, + /// Token set when any shutdown is triggered. + shutdown: CancellationToken, + + /// axum-server [`Handle`] to manage. + net_handles: Arc>>>, + unix_handles: Arc>>>, + + /// Broadcaster of program-wide events, except shutdown which is handled by tokens above. + events_tx: broadcast::Sender, +} + +impl Arbiter { + fn new(tasks: &mut JoinSet>) -> Result { + let (events_tx, _events_rx) = broadcast::channel(1024); + let arbiter = Self { + fast_shutdown: CancellationToken::new(), + graceful_shutdown: CancellationToken::new(), + shutdown: CancellationToken::new(), + + // 5 is http, https, metrics and a bit of room + net_handles: Arc::new(Mutex::new(Vec::with_capacity(5))), + // 2 is http and metrics + unix_handles: Arc::new(Mutex::new(Vec::with_capacity(2))), + + events_tx, + }; + + let streams = SignalStreams::new()?; + + tasks + .build_task() + .name(&format!("{}::watch_signals", module_path!())) + .spawn(watch_signals(streams, arbiter.clone()))?; + + Ok(arbiter) + } + + /// Add a new [`Handle`] to be managed, specifically for [`net::SocketAddr`] addresses. 
+ /// + /// This handle will be shutdown when this arbiter is shutdown. + pub async fn add_net_handle(&self, handle: Handle) { + self.net_handles.lock().await.push(handle); + } + + /// Add a new [`Handle`] to be managed, specifically for [`unix::net::SocketAddr`] addresses. + /// + /// This handle will be shutdown when this arbiter is shutdown. + pub async fn add_unix_handle(&self, handle: Handle) { + self.unix_handles.lock().await.push(handle); + } + + /// Future that will complete when the application needs to shutdown immediately. + /// + /// Consumers listening on this must also listen on [`Arbiter::graceful_shutdown`], as only one + /// of those is set upon shutdown. + /// + /// It is also possible to use [`Arbiter::shutdown`] when the behaviour is the same between a + /// fast and a graceful shutdown. + pub fn fast_shutdown(&self) -> WaitForCancellationFuture<'_> { + self.fast_shutdown.cancelled() + } + + /// Future that will complete when the application needs to shutdown gracefully. + /// + /// Consumers listening on this must also listen on [`Arbiter::fast_shutdown`], as only one + /// of those is set upon shutdown. + /// + /// It is also possible to use [`Arbiter::shutdown`] when the behaviour is the same between a + /// fast and a graceful shutdown. + pub fn graceful_shutdown(&self) -> WaitForCancellationFuture<'_> { + self.graceful_shutdown.cancelled() + } + + /// Future that will complete when the application needs to shutdown, either immediately or + /// gracefully. It's a helper so users that don't make the difference between immediate and + /// graceful shutdown don't need to handle two scenarios. + pub fn shutdown(&self) -> WaitForCancellationFuture<'_> { + self.shutdown.cancelled() + } + + /// Shutdown the application immediately. 
+ async fn do_fast_shutdown(&self) { + info!("arbiter has been told to shutdown immediately"); + self.unix_handles + .lock() + .await + .iter() + .for_each(Handle::shutdown); + self.net_handles + .lock() + .await + .iter() + .for_each(Handle::shutdown); + info!("all webservers have been shutdown, shutting down the other tasks immediately"); + self.fast_shutdown.cancel(); + self.shutdown.cancel(); + } + + /// Shutdown the application gracefully. + async fn do_graceful_shutdown(&self) { + info!("arbiter has been told to shutdown gracefully"); + // Match the value in lifecycle/gunicorn.conf.py for graceful shutdown + let timeout = Some(Duration::from_secs(30 + 5)); + self.unix_handles + .lock() + .await + .iter() + .for_each(|handle| handle.graceful_shutdown(timeout)); + self.net_handles + .lock() + .await + .iter() + .for_each(|handle| handle.graceful_shutdown(timeout)); + info!("all webservers have been shutdown, shutting down the other tasks gracefully"); + self.graceful_shutdown.cancel(); + self.shutdown.cancel(); + } + + /// Create a new [`broadcast::Receiver`] to listen for signals sent to the main process. This + /// may not include all signals we catch, since some of those will shutdown the application. + pub fn events_subscribe(&self) -> broadcast::Receiver { + self.events_tx.subscribe() + } + + /// Send an event on the program-wide events broadcast channel. + /// + /// # Errors + /// + /// See [`broadcast::Sender::send`]. + pub fn send_event(&self, value: Event) -> Result> { + self.events_tx.send(value) + } +} + +/// Events propagated throughout the program. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Event { + /// A signal has been received.
+ Signal(SignalKind), + #[cfg(test)] + Noop, +} + +impl From for Event { + fn from(value: SignalKind) -> Self { + Self::Signal(value) + } +} + +#[cfg(test)] +mod tests { + mod events { + use nix::sys::signal::{Signal, raise}; + + use super::super::*; + + async fn signal_self(signal: Signal) { + raise(signal).expect("failed to send signal"); + tokio::time::sleep(Duration::from_millis(50)).await; + } + + #[tokio::test] + async fn signals_hup() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + signal_self(Signal::SIGHUP).await; + + assert!(arbiter.fast_shutdown.is_cancelled()); + assert!(!arbiter.graceful_shutdown.is_cancelled()); + assert!(arbiter.shutdown.is_cancelled()); + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn signals_quit() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + signal_self(Signal::SIGQUIT).await; + + assert!(arbiter.fast_shutdown.is_cancelled()); + assert!(!arbiter.graceful_shutdown.is_cancelled()); + assert!(arbiter.shutdown.is_cancelled()); + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn signals_int() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + signal_self(Signal::SIGINT).await; + + assert!(arbiter.fast_shutdown.is_cancelled()); + assert!(!arbiter.graceful_shutdown.is_cancelled()); + assert!(arbiter.shutdown.is_cancelled()); + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn signals_term() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + signal_self(Signal::SIGTERM).await; + + assert!(!arbiter.fast_shutdown.is_cancelled()); + assert!(arbiter.graceful_shutdown.is_cancelled()); + assert!(arbiter.shutdown.is_cancelled()); + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn signals_other_no_listener() { + let tasks = 
Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + signal_self(Signal::SIGUSR1).await; + signal_self(Signal::SIGUSR2).await; + + arbiter.do_fast_shutdown().await; + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn signals_usr1() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + let mut events_rx = arbiter.events_subscribe(); + + signal_self(Signal::SIGUSR1).await; + + assert_eq!( + events_rx.recv().await.expect("failed to receive event"), + Event::Signal(SignalKind::user_defined1()) + ); + } + + #[tokio::test] + async fn signals_usr2() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + let mut events_rx = arbiter.events_subscribe(); + + signal_self(Signal::SIGUSR2).await; + + assert_eq!( + events_rx.recv().await.expect("failed to receive event"), + Event::Signal(SignalKind::user_defined2()), + ); + } + + #[tokio::test] + async fn events() { + let tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + let mut events_rx1 = arbiter.events_subscribe(); + let mut events_rx2 = arbiter.events_subscribe(); + + let _ = arbiter.send_event(Event::Noop); + + assert_eq!( + events_rx1.recv().await.expect("failed to receive event"), + Event::Noop, + ); + assert_eq!( + events_rx2.recv().await.expect("failed to receive event"), + Event::Noop, + ); + } + } + + mod tasks { + use eyre::eyre; + + use super::super::*; + + async fn success_task(arbiter: Arbiter) -> Result<()> { + tokio::select! 
{ + () = arbiter.fast_shutdown() => {}, + () = arbiter.graceful_shutdown() => {}, + } + Ok(()) + } + + async fn error_task(arbiter: Arbiter) -> Result<()> { + arbiter.shutdown().await; + Err(eyre!("error")) + } + + #[tokio::test] + async fn successful_tasks() { + let mut tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + for _ in 0..10_u8 { + tasks + .build_task() + .spawn(success_task(arbiter.clone())) + .expect("failed to spawn task"); + } + arbiter.do_fast_shutdown().await; + + assert_eq!(tasks.run().await.len(), 0); + } + + #[tokio::test] + async fn error_tasks() { + let mut tasks = Tasks::new().expect("tasks to create successfully"); + let arbiter = tasks.arbiter(); + + for _ in 0..10_u8 { + tasks + .build_task() + .spawn(error_task(arbiter.clone())) + .expect("failed to spawn task"); + } + arbiter.do_fast_shutdown().await; + + assert_eq!(tasks.run().await.len(), 10); + } + } +} diff --git a/packages/ak-common/src/lib.rs b/packages/ak-common/src/lib.rs index e63d65161f..2e16a8890c 100644 --- a/packages/ak-common/src/lib.rs +++ b/packages/ak-common/src/lib.rs @@ -1,5 +1,8 @@ //! Various utilities used by other crates +pub mod arbiter; +pub use arbiter::{Arbiter, Event, Tasks}; + pub const VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn authentik_build_hash(fallback: Option) -> String { diff --git a/website/scripts/docsmg/src/hackyfixes.rs b/website/scripts/docsmg/src/hackyfixes.rs index 5fd4c6dbcb..88f2e8f4ad 100644 --- a/website/scripts/docsmg/src/hackyfixes.rs +++ b/website/scripts/docsmg/src/hackyfixes.rs @@ -22,9 +22,3 @@ pub(crate) fn add_extra_dot_dot_to_expression_mdx(migrate_path: &Path) { let _ = write(file, content.replace("../expressions", "../../expressions")); } } - -#[cfg(test)] -mod tests { - #[test] - fn noop() {} -}