packages/ak-common/arbiter: init (#21253)

* packages/ak-arbiter: init

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fixup

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* add tests

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* lint

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* sort out package versions

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* rename to ak-lib

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fixup

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* packages/ak-lib: init

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* fixup

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* root: fix rustfmt config

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

* packages/ak-common: rename from ak-lib

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>

---------

Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2026-04-02 12:06:28 +00:00
committed by GitHub
parent b3036776ed
commit d3fca338b3
6 changed files with 572 additions and 6 deletions

80
Cargo.lock generated
View File

@@ -67,6 +67,15 @@ version = "1.0.102"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c"
[[package]]
name = "arc-swap"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6"
dependencies = [
"rustversion",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -105,6 +114,14 @@ dependencies = [
[[package]]
name = "authentik-common"
version = "2026.5.0-rc1"
dependencies = [
"axum-server",
"eyre",
"nix",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "autocfg"
@@ -150,6 +167,28 @@ dependencies = [
"fs_extra",
]
[[package]]
name = "axum-server"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1df331683d982a0b9492b38127151e6453639cd34926eb9c07d4cd8c6d22bfc"
dependencies = [
"arc-swap",
"bytes",
"either",
"fs-err",
"http",
"http-body",
"hyper",
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower-service",
]
[[package]]
name = "base64"
version = "0.22.1"
@@ -465,6 +504,16 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs-err"
version = "3.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73fde052dbfc920003cfd2c8e2c6e6d4cc7c1091538c3a24226cec0665ab08c0"
dependencies = [
"autocfg",
"tokio",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
@@ -662,6 +711,12 @@ version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "1.8.1"
@@ -676,6 +731,7 @@ dependencies = [
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"pin-utils",
@@ -1040,6 +1096,18 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "nix"
version = "0.31.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3"
dependencies = [
"bitflags",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]]
name = "nom"
version = "7.1.3"
@@ -1912,9 +1980,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.36"

View File

@@ -18,11 +18,13 @@ license-file = "LICENSE"
publish = false
[workspace.dependencies]
axum-server = { version = "= 0.8.0", features = ["tls-rustls-no-provider"] }
aws-lc-rs = { version = "= 1.16.2", features = ["fips"] }
clap = { version = "= 4.6.0", features = ["derive", "env"] }
colored = "= 3.1.1"
dotenvy = "= 0.15.7"
eyre = "= 0.6.12"
nix = { version = "= 0.31.2", features = ["signal"] }
regex = "= 1.12.3"
reqwest = { version = "= 0.13.2", features = [
"form",
@@ -48,6 +50,7 @@ serde_with = { version = "= 3.18.0", default-features = false, features = [
] }
tokio = { version = "= 1.50.0", features = ["full", "tracing"] }
tokio-util = { version = "= 0.7.18", features = ["full"] }
tracing = "= 0.1.44"
url = "= 2.5.8"
uuid = { version = "= 1.23.0", features = ["serde", "v4"] }

View File

@@ -10,8 +10,14 @@ license-file.workspace = true
publish.workspace = true
[dependencies]
axum-server.workspace = true
eyre.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tracing.workspace = true
[dev-dependencies]
nix.workspace = true
[lints]
workspace = true

View File

@@ -0,0 +1,480 @@
//! Utilities to manage long running tasks, such as servers and watchers, and events propagated
//! between those tasks.
//!
//! Also manages signals sent to the main process.
use std::{net, os::unix, sync::Arc, time::Duration};
use axum_server::Handle;
use eyre::{Report, Result};
use tokio::{
signal::unix::{Signal, SignalKind, signal},
sync::{Mutex, broadcast},
task::{JoinSet, join_set::Builder},
};
use tokio_util::sync::{CancellationToken, WaitForCancellationFuture};
use tracing::info;
/// All the signal streams we watch for. We don't create those directly in [`watch_signals`]
/// because that would prevent us from handling errors early.
struct SignalStreams {
    /// SIGHUP — triggers an immediate shutdown (see [`watch_signals`]).
    hup: Signal,
    /// SIGINT — triggers an immediate shutdown.
    int: Signal,
    /// SIGQUIT — triggers an immediate shutdown.
    quit: Signal,
    /// SIGUSR1 — rebroadcast to subscribers as an [`Event`].
    usr1: Signal,
    /// SIGUSR2 — rebroadcast to subscribers as an [`Event`].
    usr2: Signal,
    /// SIGTERM — triggers a graceful shutdown.
    term: Signal,
}
impl SignalStreams {
    /// Install an OS listener for every signal the arbiter reacts to.
    ///
    /// # Errors
    ///
    /// Fails if any of the underlying signal handlers cannot be registered
    /// with the runtime.
    fn new() -> Result<Self> {
        let hup = signal(SignalKind::hangup())?;
        let int = signal(SignalKind::interrupt())?;
        let quit = signal(SignalKind::quit())?;
        let usr1 = signal(SignalKind::user_defined1())?;
        let usr2 = signal(SignalKind::user_defined2())?;
        let term = signal(SignalKind::terminate())?;
        Ok(Self {
            hup,
            int,
            quit,
            usr1,
            usr2,
            term,
        })
    }
}
/// Watch for incoming signals and either shutdown the application or dispatch them to receivers.
///
/// HUP, INT and QUIT trigger an immediate shutdown; TERM triggers a graceful one; USR1 and
/// USR2 are rebroadcast to subscribers as [`Event`]s. The loop exits cleanly once the
/// arbiter's shutdown token is cancelled.
async fn watch_signals(streams: SignalStreams, arbiter: Arbiter) -> Result<()> {
    info!("starting signals watcher");
    let SignalStreams {
        mut hup,
        mut int,
        mut quit,
        mut usr1,
        mut usr2,
        mut term,
    } = streams;
    loop {
        tokio::select! {
            _ = hup.recv() => {
                info!("signal HUP received");
                arbiter.do_fast_shutdown().await;
            },
            _ = int.recv() => {
                info!("signal INT received");
                arbiter.do_fast_shutdown().await;
            },
            _ = quit.recv() => {
                info!("signal QUIT received");
                arbiter.do_fast_shutdown().await;
            },
            _ = term.recv() => {
                info!("signal TERM received");
                arbiter.do_graceful_shutdown().await;
            },
            _ = usr1.recv() => {
                // Fixed log message: was the typo "signal URS1 received".
                info!("signal USR1 received");
                // Send fails when there are no subscribers; that's fine, ignore it.
                let _ = arbiter.send_event(SignalKind::user_defined1().into());
            },
            _ = usr2.recv() => {
                // Normalized to match the other arms: was "USR2 received.".
                info!("signal USR2 received");
                let _ = arbiter.send_event(SignalKind::user_defined2().into());
            },
            () = arbiter.shutdown() => {
                info!("stopping signals watcher");
                return Ok(());
            }
        };
    }
}
/// Manager for long running tasks, such as servers and watchers.
pub struct Tasks {
    /// The spawned tasks; joined to completion in [`Tasks::run`].
    tasks: JoinSet<Result<()>>,
    /// Shared shutdown/event coordinator, cloned out via [`Tasks::arbiter`].
    arbiter: Arbiter,
}
impl Tasks {
    /// Create a new [`Tasks`] manager with its signals-watcher task already spawned.
    ///
    /// # Errors
    ///
    /// Errors if the signals watcher cannot be created.
    pub fn new() -> Result<Self> {
        let mut tasks = JoinSet::new();
        let arbiter = Arbiter::new(&mut tasks)?;
        Ok(Self { tasks, arbiter })
    }

    /// Build a new task. See [`tokio::task::JoinSet::build_task`] for details.
    pub fn build_task(&mut self) -> Builder<'_, Result<()>> {
        self.tasks.build_task()
    }

    /// Get an [`Arbiter`].
    pub fn arbiter(&self) -> Arbiter {
        self.arbiter.clone()
    }

    /// Run the tasks until completion. If one of them fails, terminate the program immediately.
    ///
    /// Returns every error produced by the tasks; an empty vector means a clean run.
    pub async fn run(self) -> Vec<Report> {
        let Self { mut tasks, arbiter } = self;
        let mut errors = Vec::new();
        // The first task to finish starts the shutdown sequence for everything else.
        let Some(first) = tasks.join_next().await else {
            return errors;
        };
        arbiter.do_graceful_shutdown().await;
        // Flatten JoinError and task error into a single Result before inspecting it.
        if let Err(err) = first.map_err(Report::new).and_then(|inner| inner) {
            // A failing first task escalates the shutdown to an immediate one.
            arbiter.do_fast_shutdown().await;
            errors.push(err);
        }
        // Drain the remaining tasks, collecting any further failures.
        while let Some(joined) = tasks.join_next().await {
            if let Err(err) = joined.map_err(Report::new).and_then(|inner| inner) {
                errors.push(err);
            }
        }
        errors
    }
}
/// Manage shutdown state and several communication channels.
#[derive(Clone)]
pub struct Arbiter {
    /// Token to shutdown the application immediately.
    fast_shutdown: CancellationToken,
    /// Token to shutdown the application gracefully.
    graceful_shutdown: CancellationToken,
    /// Token set when any shutdown is triggered.
    shutdown: CancellationToken,
    /// axum-server [`Handle`]s bound to TCP sockets; shut down alongside this arbiter.
    net_handles: Arc<Mutex<Vec<Handle<net::SocketAddr>>>>,
    /// axum-server [`Handle`]s bound to Unix sockets; shut down alongside this arbiter.
    unix_handles: Arc<Mutex<Vec<Handle<unix::net::SocketAddr>>>>,
    /// Broadcaster of program-wide events, except shutdown which is handled by tokens above.
    events_tx: broadcast::Sender<Event>,
}
impl Arbiter {
    /// Build an arbiter and spawn its signals-watcher task onto `tasks`.
    fn new(tasks: &mut JoinSet<Result<()>>) -> Result<Self> {
        let (events_tx, _initial_rx) = broadcast::channel(1024);
        let arbiter = Self {
            fast_shutdown: CancellationToken::new(),
            graceful_shutdown: CancellationToken::new(),
            shutdown: CancellationToken::new(),
            // 5 is http, https, metrics and a bit of room
            net_handles: Arc::new(Mutex::new(Vec::with_capacity(5))),
            // 2 is http and metrics
            unix_handles: Arc::new(Mutex::new(Vec::with_capacity(2))),
            events_tx,
        };
        let streams = SignalStreams::new()?;
        let task_name = format!("{}::watch_signals", module_path!());
        tasks
            .build_task()
            .name(&task_name)
            .spawn(watch_signals(streams, arbiter.clone()))?;
        Ok(arbiter)
    }

    /// Register a [`Handle`] bound to a [`net::SocketAddr`] for management.
    ///
    /// This handle will be shutdown when this arbiter is shutdown.
    pub async fn add_net_handle(&self, handle: Handle<net::SocketAddr>) {
        let mut handles = self.net_handles.lock().await;
        handles.push(handle);
    }

    /// Register a [`Handle`] bound to a [`unix::net::SocketAddr`] for management.
    ///
    /// This handle will be shutdown when this arbiter is shutdown.
    pub async fn add_unix_handle(&self, handle: Handle<unix::net::SocketAddr>) {
        let mut handles = self.unix_handles.lock().await;
        handles.push(handle);
    }

    /// Future that completes when the application must shutdown immediately.
    ///
    /// Consumers listening on this must also listen on [`Arbiter::graceful_shutdown`], as only
    /// one of those is set upon shutdown. When fast and graceful behave the same for you,
    /// prefer [`Arbiter::shutdown`].
    pub fn fast_shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.fast_shutdown.cancelled()
    }

    /// Future that completes when the application must shutdown gracefully.
    ///
    /// Consumers listening on this must also listen on [`Arbiter::fast_shutdown`], as only
    /// one of those is set upon shutdown. When fast and graceful behave the same for you,
    /// prefer [`Arbiter::shutdown`].
    pub fn graceful_shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.graceful_shutdown.cancelled()
    }

    /// Future that completes on any shutdown, immediate or graceful. A helper for consumers
    /// that do not distinguish between the two scenarios.
    pub fn shutdown(&self) -> WaitForCancellationFuture<'_> {
        self.shutdown.cancelled()
    }

    /// Shutdown the application immediately: stop every webserver handle, then cancel the
    /// fast and generic shutdown tokens.
    async fn do_fast_shutdown(&self) {
        info!("arbiter has been told to shutdown immediately");
        for handle in self.unix_handles.lock().await.iter() {
            handle.shutdown();
        }
        for handle in self.net_handles.lock().await.iter() {
            handle.shutdown();
        }
        info!("all webservers have been shutdown, shutting down the other tasks immediately");
        self.fast_shutdown.cancel();
        self.shutdown.cancel();
    }

    /// Shutdown the application gracefully: give webservers a bounded drain window, then
    /// cancel the graceful and generic shutdown tokens.
    async fn do_graceful_shutdown(&self) {
        info!("arbiter has been told to shutdown gracefully");
        // Match the value in lifecycle/gunicorn.conf.py for graceful shutdown
        let timeout = Some(Duration::from_secs(30 + 5));
        for handle in self.unix_handles.lock().await.iter() {
            handle.graceful_shutdown(timeout);
        }
        for handle in self.net_handles.lock().await.iter() {
            handle.graceful_shutdown(timeout);
        }
        info!("all webservers have been shutdown, shutting down the other tasks gracefully");
        self.graceful_shutdown.cancel();
        self.shutdown.cancel();
    }

    /// Create a new [`broadcast::Receiver`] of program-wide [`Event`]s. This may not include
    /// all signals we catch, since some of those shutdown the application instead.
    pub fn events_subscribe(&self) -> broadcast::Receiver<Event> {
        self.events_tx.subscribe()
    }

    /// Broadcast an [`Event`] to all current subscribers.
    ///
    /// # Errors
    ///
    /// See [`broadcast::Sender::send`]; in particular, sending fails when there are no
    /// active receivers.
    pub fn send_event(&self, value: Event) -> Result<usize, broadcast::error::SendError<Event>> {
        self.events_tx.send(value)
    }
}
/// Events propagated throughout the program.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    /// A signal has been received.
    Signal(SignalKind),
    /// Payload-free event used only by tests to exercise the broadcast channel.
    #[cfg(test)]
    Noop,
}
impl From<SignalKind> for Event {
fn from(value: SignalKind) -> Self {
Self::Signal(value)
}
}
#[cfg(test)]
mod tests {
    //! NOTE(review): these tests raise real signals at the test process, so they assume the
    //! test binary runs them in an environment where that is safe.

    /// Tests around signal handling and the events broadcast channel.
    mod events {
        use nix::sys::signal::{Signal, raise};
        use super::super::*;

        /// Send `signal` to the current process, then yield long enough for the
        /// watcher task to react.
        ///
        /// NOTE(review): the 50ms sleep assumes the watcher gets scheduled within that
        /// window — a heavily loaded runner could make these tests flaky.
        async fn signal_self(signal: Signal) {
            raise(signal).expect("failed to send signal");
            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        // HUP must trigger a fast (immediate) shutdown, not a graceful one.
        #[tokio::test]
        async fn signals_hup() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            signal_self(Signal::SIGHUP).await;
            assert!(arbiter.fast_shutdown.is_cancelled());
            assert!(!arbiter.graceful_shutdown.is_cancelled());
            assert!(arbiter.shutdown.is_cancelled());
            assert_eq!(tasks.run().await.len(), 0);
        }

        // QUIT must trigger a fast (immediate) shutdown, not a graceful one.
        #[tokio::test]
        async fn signals_quit() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            signal_self(Signal::SIGQUIT).await;
            assert!(arbiter.fast_shutdown.is_cancelled());
            assert!(!arbiter.graceful_shutdown.is_cancelled());
            assert!(arbiter.shutdown.is_cancelled());
            assert_eq!(tasks.run().await.len(), 0);
        }

        // INT must trigger a fast (immediate) shutdown, not a graceful one.
        #[tokio::test]
        async fn signals_int() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            signal_self(Signal::SIGINT).await;
            assert!(arbiter.fast_shutdown.is_cancelled());
            assert!(!arbiter.graceful_shutdown.is_cancelled());
            assert!(arbiter.shutdown.is_cancelled());
            assert_eq!(tasks.run().await.len(), 0);
        }

        // TERM must trigger a graceful shutdown, not a fast one.
        #[tokio::test]
        async fn signals_term() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            signal_self(Signal::SIGTERM).await;
            assert!(!arbiter.fast_shutdown.is_cancelled());
            assert!(arbiter.graceful_shutdown.is_cancelled());
            assert!(arbiter.shutdown.is_cancelled());
            assert_eq!(tasks.run().await.len(), 0);
        }

        // USR1/USR2 with no subscriber must not error or shut anything down.
        #[tokio::test]
        async fn signals_other_no_listener() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            signal_self(Signal::SIGUSR1).await;
            signal_self(Signal::SIGUSR2).await;
            arbiter.do_fast_shutdown().await;
            assert_eq!(tasks.run().await.len(), 0);
        }

        // USR1 is rebroadcast to subscribers as an Event::Signal.
        #[tokio::test]
        async fn signals_usr1() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            let mut events_rx = arbiter.events_subscribe();
            signal_self(Signal::SIGUSR1).await;
            assert_eq!(
                events_rx.recv().await.expect("failed to receive event"),
                Event::Signal(SignalKind::user_defined1())
            );
        }

        // USR2 is rebroadcast to subscribers as an Event::Signal.
        #[tokio::test]
        async fn signals_usr2() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            let mut events_rx = arbiter.events_subscribe();
            signal_self(Signal::SIGUSR2).await;
            assert_eq!(
                events_rx.recv().await.expect("failed to receive event"),
                Event::Signal(SignalKind::user_defined2()),
            );
        }

        // A sent event is delivered to every subscriber, not just one.
        #[tokio::test]
        async fn events() {
            let tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            let mut events_rx1 = arbiter.events_subscribe();
            let mut events_rx2 = arbiter.events_subscribe();
            let _ = arbiter.send_event(Event::Noop);
            assert_eq!(
                events_rx1.recv().await.expect("failed to receive event"),
                Event::Noop,
            );
            assert_eq!(
                events_rx2.recv().await.expect("failed to receive event"),
                Event::Noop,
            );
        }
    }

    /// Tests around task lifecycle and error collection in [`Tasks::run`].
    mod tasks {
        use eyre::eyre;
        use super::super::*;

        /// Task that completes cleanly once either shutdown token is cancelled.
        async fn success_task(arbiter: Arbiter) -> Result<()> {
            tokio::select! {
                () = arbiter.fast_shutdown() => {},
                () = arbiter.graceful_shutdown() => {},
            }
            Ok(())
        }

        /// Task that waits for any shutdown, then fails — used to exercise error collection.
        async fn error_task(arbiter: Arbiter) -> Result<()> {
            arbiter.shutdown().await;
            Err(eyre!("error"))
        }

        // All-successful tasks produce an empty error vector from run().
        #[tokio::test]
        async fn successful_tasks() {
            let mut tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            for _ in 0..10_u8 {
                tasks
                    .build_task()
                    .spawn(success_task(arbiter.clone()))
                    .expect("failed to spawn task");
            }
            arbiter.do_fast_shutdown().await;
            assert_eq!(tasks.run().await.len(), 0);
        }

        // Every failing task contributes exactly one error to run()'s result.
        #[tokio::test]
        async fn error_tasks() {
            let mut tasks = Tasks::new().expect("tasks to create successfully");
            let arbiter = tasks.arbiter();
            for _ in 0..10_u8 {
                tasks
                    .build_task()
                    .spawn(error_task(arbiter.clone()))
                    .expect("failed to spawn task");
            }
            arbiter.do_fast_shutdown().await;
            assert_eq!(tasks.run().await.len(), 10);
        }
    }
}

View File

@@ -1,5 +1,8 @@
//! Various utilities used by other crates
pub mod arbiter;
pub use arbiter::{Arbiter, Event, Tasks};
/// Crate version, taken from the Cargo package metadata at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub fn authentik_build_hash(fallback: Option<String>) -> String {

View File

@@ -22,9 +22,3 @@ pub(crate) fn add_extra_dot_dot_to_expression_mdx(migrate_path: &Path) {
let _ = write(file, content.replace("../expressions", "../../expressions"));
}
}
#[cfg(test)]
mod tests {
#[test]
fn noop() {}
}