packages/ak-common/db: init (#21357)

This commit is contained in:
Marc 'risson' Schmitt
2026-04-09 11:57:44 +00:00
committed by GitHub
parent dedbbee55c
commit 0dbd6a68b6
6 changed files with 1016 additions and 11 deletions

View File

@@ -329,7 +329,7 @@ jobs:
- name: Setup authentik env
uses: ./.github/actions/setup
with:
dependencies: rust
dependencies: rust,runtime
- name: run tests
run: |
cargo llvm-cov --no-report nextest --workspace

789
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -75,6 +75,17 @@ serde_repr = "= 0.1.20"
serde_with = { version = "= 3.18.0", default-features = false, features = [
"base64",
] }
sqlx = { version = "= 0.8.6", default-features = false, features = [
"runtime-tokio",
"tls-rustls-aws-lc-rs",
"postgres",
"derive",
"macros",
"uuid",
"chrono",
"ipnet",
"json",
] }
tempfile = "= 3.27.0"
thiserror = "= 2.0.18"
time = { version = "= 0.3.47", features = ["macros"] }

View File

@@ -11,7 +11,7 @@ publish.workspace = true
[features]
default = ["core", "proxy"]
core = []
core = ["dep:sqlx"]
proxy = []
[dependencies]
@@ -30,6 +30,7 @@ rustls.workspace = true
sentry.workspace = true
serde.workspace = true
serde_json.workspace = true
sqlx = { workspace = true, optional = true }
thiserror.workspace = true
time.workspace = true
tokio-util.workspace = true

View File

@@ -0,0 +1,220 @@
use std::{str::FromStr as _, sync::OnceLock, time::Duration};
use eyre::Result;
use sqlx::{
ConnectOptions as _, Executor as _, PgConnection, PgPool,
postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
};
use tracing::{info, log::LevelFilter, trace};
use crate::{
arbiter::{Arbiter, Event, Tasks},
authentik_full_version, config,
mode::Mode,
};
static DB: OnceLock<PgPool> = OnceLock::new();
fn get_connect_opts() -> Result<PgConnectOptions> {
let config = config::get();
let mut opts = PgConnectOptions::new()
.application_name(&format!(
"authentik-{}@{}",
Mode::get(),
authentik_full_version()
))
.host(&config.postgresql.host)
.port(config.postgresql.port)
.username(&config.postgresql.user)
.password(&config.postgresql.password)
.database(&config.postgresql.name)
.ssl_mode(PgSslMode::from_str(&config.postgresql.sslmode)?);
if let Some(sslrootcert) = &config.postgresql.sslrootcert {
opts = opts.ssl_root_cert_from_pem(sslrootcert.as_bytes().to_vec());
}
if let Some(sslcert) = &config.postgresql.sslcert {
opts = opts.ssl_client_cert_from_pem(sslcert.as_bytes());
}
if let Some(sslkey) = &config.postgresql.sslkey {
opts = opts.ssl_client_key_from_pem(sslkey.as_bytes());
}
Ok(opts)
}
async fn update_connect_opts_on_config_change(arbiter: Arbiter) -> Result<()> {
let mut events_rx = arbiter.events_subscribe();
info!("starting database watcher for config changes");
loop {
tokio::select! {
Ok(Event::ConfigChanged) = events_rx.recv() => {
trace!("config change received, refreshing database connection options");
let db = get();
db.set_connect_options(get_connect_opts()?);
},
() = arbiter.shutdown() => {
info!("stopping database watcher for config changes");
return Ok(());
},
}
}
}
pub async fn init(tasks: &mut Tasks) -> Result<()> {
info!("initializing database pool");
let options = get_connect_opts()?;
let config = config::get();
let pool_options = PgPoolOptions::new()
.min_connections(1)
.max_connections(4)
.acquire_time_level(LevelFilter::Trace)
.max_lifetime(config.postgresql.conn_max_age.map(Duration::from_secs))
.test_before_acquire(config.postgresql.conn_health_checks)
.after_connect(|conn, _meta| {
Box::pin(async move {
let application_name =
format!("authentik-{}@{}", Mode::get(), authentik_full_version());
let default_schema = &config::get().postgresql.default_schema;
let query = format!(
"SET application_name = '{application_name}'; SET search_path = \
'{default_schema}';"
);
conn.execute(query.as_str()).await?;
Ok(())
})
});
let pool = pool_options.connect_with(options).await?;
DB.get_or_init(|| pool);
let arbiter = tasks.arbiter();
tasks
.build_task()
.name(&format!(
"{}::update_connect_opts_on_config_change",
module_path!(),
))
.spawn(update_connect_opts_on_config_change(arbiter))?;
info!("database pool initialized");
Ok(())
}
pub fn get() -> &'static PgPool {
DB.get()
.expect("failed to get db, has it been initialized?")
}
pub async fn create_conn() -> Result<PgConnection> {
let options = get_connect_opts()?;
let conn = options.connect().await?;
Ok(conn)
}
#[cfg(test)]
mod tests {
use serde_json::json;
use sqlx::postgres::PgSslMode;
use tokio::time::{Duration, sleep};
use crate::{
arbiter::{Event, Tasks},
config,
};
#[tokio::test]
async fn init() {
std::env::set_current_dir(format!("{}/../../", env!("CARGO_MANIFEST_DIR")))
.expect("failed to chdir");
config::init().expect("failed to init config");
let mut tasks = Tasks::new().expect("failed to create tasks");
super::init(&mut tasks).await.expect("failed to init db");
}
#[tokio::test]
async fn get() {
std::env::set_current_dir(format!("{}/../../", env!("CARGO_MANIFEST_DIR")))
.expect("failed to chdir");
config::init().expect("failed to init config");
let mut tasks = Tasks::new().expect("failed to create tasks");
super::init(&mut tasks).await.expect("failed to init db");
sqlx::query("SELECT 1")
.execute(super::get())
.await
.expect("failed to execute query");
}
#[tokio::test]
async fn conn_options() {
std::env::set_current_dir(format!("{}/../../", env!("CARGO_MANIFEST_DIR")))
.expect("failed to chdir");
config::init().expect("failed to init config");
let mut tasks = Tasks::new().expect("failed to create tasks");
super::init(&mut tasks).await.expect("failed to init db");
assert_eq!(config::get().postgresql.default_schema, "public");
let row: (String,) = sqlx::query_as("SHOW search_path")
.fetch_one(super::get())
.await
.expect("failed to run query");
assert_eq!(row.0, "public");
let row: (String,) = sqlx::query_as("SHOW application_name")
.fetch_one(super::get())
.await
.expect("failed to run query");
assert!(row.0.contains("authentik"));
}
#[tokio::test]
async fn config_update() {
std::env::set_current_dir(format!("{}/../../", env!("CARGO_MANIFEST_DIR")))
.expect("failed to chdir");
config::init().expect("failed to init config");
let mut tasks = Tasks::new().expect("failed to create tasks");
let arbiter = tasks.arbiter();
super::init(&mut tasks).await.expect("failed to init db");
// Wait for the background tasks to start.
sleep(Duration::from_millis(100)).await;
assert!(matches!(
super::get().connect_options().get_ssl_mode(),
PgSslMode::Disable
));
config::set(json!({
"postgresql": {
"sslmode": "prefer",
},
}))
.expect("failed to set config");
arbiter
.send_event(Event::ConfigChanged)
.expect("failed to send config changed event");
// Wait for the change to propagate.
sleep(Duration::from_millis(100)).await;
assert!(matches!(
super::get().connect_options().get_ssl_mode(),
PgSslMode::Prefer
));
}
#[tokio::test]
async fn create_conn() {
std::env::set_current_dir(format!("{}/../../", env!("CARGO_MANIFEST_DIR")))
.expect("failed to chdir");
config::init().expect("failed to init config");
let mut conn = super::create_conn().await.expect("failed to create conn");
sqlx::query("SELECT 1")
.execute(&mut conn)
.await
.expect("failed to run query");
}
}

View File

@@ -3,6 +3,8 @@
pub mod arbiter;
pub use arbiter::{Arbiter, Event, Tasks};
pub mod config;
#[cfg(feature = "core")]
pub mod db;
pub mod mode;
pub use mode::Mode;
pub mod tls;