mirror of
https://github.com/goauthentik/authentik
synced 2026-04-25 17:15:26 +02:00
@@ -150,6 +150,8 @@ perf = { priority = -1, level = "warn" }
|
||||
style = { priority = -1, level = "warn" }
|
||||
suspicious = { priority = -1, level = "warn" }
|
||||
### and disable the ones we don't want
|
||||
### cargo group
|
||||
multiple_crate_versions = "allow"
|
||||
### pedantic group
|
||||
redundant_closure_for_method_calls = "allow"
|
||||
struct_field_names = "allow"
|
||||
@@ -170,7 +172,6 @@ create_dir = "warn"
|
||||
dbg_macro = "warn"
|
||||
default_numeric_fallback = "warn"
|
||||
disallowed_script_idents = "warn"
|
||||
doc_paragraphs_missing_punctuation = "warn"
|
||||
empty_drop = "warn"
|
||||
empty_enum_variants_with_brackets = "warn"
|
||||
empty_structs_with_brackets = "warn"
|
||||
|
||||
@@ -4,7 +4,7 @@ use axum::{Extension, middleware::AddExtension};
|
||||
use axum_server::accept::{Accept, DefaultAcceptor};
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tower::Layer;
|
||||
use tower::Layer as _;
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};
|
||||
|
||||
@@ -3,7 +3,7 @@ use axum_server::{accept::Accept, tls_rustls::RustlsAcceptor};
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::{rustls::pki_types::CertificateDer, server::TlsStream};
|
||||
use tower::Layer;
|
||||
use tower::Layer as _;
|
||||
use tracing::instrument;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
|
||||
|
||||
use axum::{
|
||||
Extension, RequestPartsExt,
|
||||
Extension, RequestPartsExt as _,
|
||||
extract::{ConnectInfo, FromRequestParts, Request},
|
||||
http::request::Parts,
|
||||
middleware::Next,
|
||||
@@ -81,7 +81,7 @@ pub(crate) async fn client_ip_middleware(request: Request, next: Next) -> Respon
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{Ipv4Addr, Ipv6Addr};
|
||||
use std::net::Ipv4Addr;
|
||||
|
||||
use axum::{body::Body, http::Request};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use axum::{
|
||||
Extension, RequestPartsExt,
|
||||
Extension, RequestPartsExt as _,
|
||||
extract::{FromRequestParts, Request},
|
||||
http::{
|
||||
header::{FORWARDED, HOST},
|
||||
@@ -7,7 +7,7 @@ use axum::{
|
||||
status::StatusCode,
|
||||
},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
response::{IntoResponse as _, Response},
|
||||
};
|
||||
use forwarded_header_value::ForwardedHeaderValue;
|
||||
use tracing::{Span, instrument};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use axum::{
|
||||
Extension, RequestPartsExt,
|
||||
Extension, RequestPartsExt as _,
|
||||
extract::{FromRequestParts, Request},
|
||||
http::{self, header::FORWARDED, request::Parts},
|
||||
middleware::Next,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{
|
||||
Extension, RequestPartsExt,
|
||||
Extension, RequestPartsExt as _,
|
||||
extract::{ConnectInfo, FromRequestParts, Request},
|
||||
http::request::Parts,
|
||||
middleware::Next,
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::HashMap;
|
||||
|
||||
use axum::{extract::Request, middleware::Next, response::Response};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{Instrument, field, info, info_span, trace};
|
||||
use tracing::{Instrument as _, field, info, info_span, trace};
|
||||
|
||||
use crate::config;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use eyre::{Report, Result};
|
||||
use rustls::{
|
||||
RootCertStore,
|
||||
crypto::CryptoProvider,
|
||||
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
|
||||
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
|
||||
server::ClientHello,
|
||||
sign::CertifiedKey,
|
||||
};
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::{
|
||||
|
||||
use arc_swap::{ArcSwap, Guard};
|
||||
use eyre::Result;
|
||||
use notify::{RecommendedWatcher, Watcher};
|
||||
use notify::{RecommendedWatcher, Watcher as _};
|
||||
use serde_json::{Map, Value};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{error, info, warn};
|
||||
@@ -86,7 +86,7 @@ impl Config {
|
||||
"file" => {
|
||||
let path = uri.path();
|
||||
match read_to_string(path).map(|s| s.trim().to_owned()) {
|
||||
Ok(value) => return (value.to_owned(), Some(PathBuf::from(path))),
|
||||
Ok(value) => return (value, Some(PathBuf::from(path))),
|
||||
Err(err) => {
|
||||
error!("failed to read config value from {path}: {err}");
|
||||
return (fallback, Some(PathBuf::from(path)));
|
||||
@@ -96,7 +96,7 @@ impl Config {
|
||||
"env" => {
|
||||
if let Some(var) = uri.host_str() {
|
||||
if let Ok(value) = env::var(var) {
|
||||
return (value.to_owned(), None);
|
||||
return (value, None);
|
||||
}
|
||||
return (fallback, None);
|
||||
}
|
||||
@@ -141,10 +141,10 @@ impl Config {
|
||||
(value, file_paths)
|
||||
}
|
||||
|
||||
fn load(config_paths: &[PathBuf]) -> Result<(Config, Vec<PathBuf>)> {
|
||||
fn load(config_paths: &[PathBuf]) -> Result<(Self, Vec<PathBuf>)> {
|
||||
let raw = Self::load_raw(config_paths)?;
|
||||
let (expanded, file_paths) = Self::expand(raw);
|
||||
let config: Config = serde_json::from_value(expanded)?;
|
||||
let config: Self = serde_json::from_value(expanded)?;
|
||||
Ok((config, file_paths))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::{str::FromStr, sync::OnceLock, time::Duration};
|
||||
use std::{str::FromStr as _, sync::OnceLock, time::Duration};
|
||||
|
||||
use eyre::Result;
|
||||
use sqlx::{
|
||||
Executor, PgPool,
|
||||
Executor as _, PgPool,
|
||||
postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
|
||||
};
|
||||
use tracing::{info, log::LevelFilter, trace};
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
process::exit,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
},
|
||||
};
|
||||
|
||||
use ::tracing::{error, info, trace};
|
||||
|
||||
@@ -50,7 +50,7 @@ pub(super) async fn metrics_handler(State(state): State<Arc<Metrics>>) -> Result
|
||||
mod python {
|
||||
use eyre::{Report, Result};
|
||||
use pyo3::{
|
||||
IntoPyObjectExt,
|
||||
IntoPyObjectExt as _,
|
||||
ffi::c_str,
|
||||
prelude::*,
|
||||
types::{PyBytes, PyDict},
|
||||
@@ -81,7 +81,7 @@ output = generate_latest(registry)
|
||||
.get_item("output")?
|
||||
.unwrap_or(PyBytes::new(py, &[]).into_bound_py_any(py)?)
|
||||
.cast::<PyBytes>()
|
||||
.unwrap_or(&PyBytes::new(py, &[]))
|
||||
.map_or_else(|_| PyBytes::new(py, &[]), |v| v.to_owned())
|
||||
.as_bytes()
|
||||
.to_owned();
|
||||
Ok::<_, Report>(metrics)
|
||||
|
||||
@@ -42,7 +42,7 @@ async fn run_upkeep(arbiter: Arbiter, state: Arc<Metrics>) -> Result<()> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
() = tokio::time::sleep(Duration::from_secs(5)) => {
|
||||
let state_clone = state.clone();
|
||||
let state_clone = Arc::clone(&state);
|
||||
tokio::task::spawn_blocking(move || state_clone.prometheus.run_upkeep()).await?;
|
||||
},
|
||||
() = arbiter.shutdown() => return Ok(())
|
||||
@@ -62,12 +62,12 @@ fn build_router(state: Arc<Metrics>) -> Router {
|
||||
pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
|
||||
let arbiter = tasks.arbiter();
|
||||
let metrics = Arc::new(Metrics::new()?);
|
||||
let router = build_router(metrics.clone());
|
||||
let router = build_router(Arc::clone(&metrics));
|
||||
|
||||
tasks
|
||||
.build_task()
|
||||
.name(&format!("{}::run_upkeep", module_path!(),))
|
||||
.spawn(run_upkeep(arbiter.clone(), metrics.clone()))?;
|
||||
.spawn(run_upkeep(arbiter, Arc::clone(&metrics)))?;
|
||||
|
||||
for addr in config::get().listen.metrics.iter().copied() {
|
||||
server::start_plain(tasks, "metrics", router.clone(), addr)?;
|
||||
@@ -77,11 +77,7 @@ pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
|
||||
tasks,
|
||||
"metrics",
|
||||
router,
|
||||
unix::net::SocketAddr::from_pathname({
|
||||
let mut path = temp_dir();
|
||||
path.push("authentik-metrics.sock");
|
||||
path
|
||||
})?,
|
||||
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik-metrics.sock"))?,
|
||||
)?;
|
||||
|
||||
Ok(metrics)
|
||||
|
||||
@@ -40,6 +40,13 @@ impl std::fmt::Display for Mode {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Mode> for u8 {
|
||||
#[expect(clippy::as_conversions, reason = "repr of enum is u8")]
|
||||
fn from(value: Mode) -> Self {
|
||||
value as Self
|
||||
}
|
||||
}
|
||||
|
||||
impl Mode {
|
||||
pub(crate) fn get() -> Self {
|
||||
match MODE.load(Ordering::Relaxed) {
|
||||
@@ -57,7 +64,7 @@ impl Mode {
|
||||
|
||||
pub(crate) fn set(mode: Self) -> Result<()> {
|
||||
std::fs::write(mode_path(), mode.to_string())?;
|
||||
MODE.store(mode as u8, Ordering::SeqCst);
|
||||
MODE.store(mode.into(), Ordering::SeqCst);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ use crate::arbiter::{Arbiter, Tasks};
|
||||
#[derive(Debug, FromArgs, PartialEq)]
|
||||
/// Run the authentik proxy outpost.
|
||||
#[argh(subcommand, name = "proxy")]
|
||||
#[expect(
|
||||
clippy::empty_structs_with_brackets,
|
||||
reason = "argh doesn't support unit structs"
|
||||
)]
|
||||
pub(crate) struct Cli {}
|
||||
|
||||
pub(crate) mod tls {
|
||||
@@ -15,9 +19,10 @@ pub(crate) mod tls {
|
||||
use rustls::{server::ClientHello, sign::CertifiedKey};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct CertResolver {}
|
||||
pub(crate) struct CertResolver;
|
||||
|
||||
impl CertResolver {
|
||||
#[expect(clippy::unused_self, reason = "still WIP")]
|
||||
pub(crate) fn resolve(&self, _client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
||||
None
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use axum::{
|
||||
response::{IntoResponse, Response},
|
||||
routing::any,
|
||||
};
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::BodyExt as _;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{
|
||||
@@ -226,7 +226,7 @@ pub(super) fn build_router(server: Arc<Server>) -> Router {
|
||||
Router::new()
|
||||
.route("/-/metrics/", any((StatusCode::NOT_FOUND, "not found")))
|
||||
.route("/-/health/ready/", any(health_ready))
|
||||
.with_state(server.clone())
|
||||
.with_state(Arc::clone(&server))
|
||||
.merge(super::r#static::build_router()),
|
||||
true,
|
||||
)
|
||||
@@ -264,9 +264,9 @@ mod websockets {
|
||||
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
|
||||
},
|
||||
},
|
||||
response::{IntoResponse, Response},
|
||||
response::{IntoResponse as _, Response},
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use futures::{SinkExt as _, StreamExt as _};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::{net::UnixStream, sync::mpsc};
|
||||
use tokio_tungstenite::{
|
||||
@@ -373,7 +373,7 @@ mod websockets {
|
||||
Message::Close(_) => {
|
||||
if !client_closed {
|
||||
upstream_sender.send(Message::Close(None)).await?;
|
||||
close_tx.send(()).await.ok();
|
||||
let _ = close_tx.send(()).await;
|
||||
client_closed = true;
|
||||
break;
|
||||
}
|
||||
@@ -391,7 +391,7 @@ mod websockets {
|
||||
}
|
||||
if !client_closed {
|
||||
upstream_sender.send(Message::Close(None)).await?;
|
||||
close_tx.send(()).await.ok();
|
||||
let _ = close_tx.send(()).await;
|
||||
}
|
||||
Ok::<_, AppError>(())
|
||||
});
|
||||
@@ -404,7 +404,7 @@ mod websockets {
|
||||
Message::Close(_) => {
|
||||
if !upstream_closed {
|
||||
client_sender.send(Message::Close(None)).await?;
|
||||
close_tx_upstream.send(()).await.ok();
|
||||
let _ = close_tx_upstream.send(()).await;
|
||||
upstream_closed = true;
|
||||
break;
|
||||
}
|
||||
@@ -422,7 +422,7 @@ mod websockets {
|
||||
}
|
||||
if !upstream_closed {
|
||||
client_sender.send(Message::Close(None)).await?;
|
||||
close_tx_upstream.send(()).await.ok();
|
||||
let _ = close_tx_upstream.send(()).await;
|
||||
}
|
||||
Ok::<_, AppError>(())
|
||||
});
|
||||
|
||||
@@ -27,7 +27,7 @@ use tokio::{
|
||||
sync::{Mutex, broadcast::error::RecvError},
|
||||
time::Instant,
|
||||
};
|
||||
use tower::ServiceExt;
|
||||
use tower::ServiceExt as _;
|
||||
use tower_http::timeout::TimeoutLayer;
|
||||
use tracing::{info, trace, warn};
|
||||
|
||||
@@ -47,6 +47,10 @@ mod tls;
|
||||
#[derive(Debug, Default, FromArgs, PartialEq)]
|
||||
/// Run the authentik server.
|
||||
#[argh(subcommand, name = "server")]
|
||||
#[expect(
|
||||
clippy::empty_structs_with_brackets,
|
||||
reason = "argh doesn't support unit structs"
|
||||
)]
|
||||
pub(super) struct Cli {}
|
||||
|
||||
pub(crate) struct Server {
|
||||
@@ -90,10 +94,12 @@ impl Server {
|
||||
signal = signal.as_str(),
|
||||
"sending shutdown signal to gunicorn"
|
||||
);
|
||||
if let Some(id) = self.gunicorn.lock().await.id() {
|
||||
let mut gunicorn = self.gunicorn.lock().await;
|
||||
if let Some(id) = gunicorn.id() {
|
||||
kill(Pid::from_raw(id.cast_signed()), signal)?;
|
||||
}
|
||||
self.gunicorn.lock().await.wait().await?;
|
||||
gunicorn.wait().await?;
|
||||
drop(gunicorn);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -221,9 +227,9 @@ pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
|
||||
tasks
|
||||
.build_task()
|
||||
.name(&format!("{}::watch_server", module_path!()))
|
||||
.spawn(watch_server(arbiter.clone(), server.clone()))?;
|
||||
.spawn(watch_server(arbiter.clone(), Arc::clone(&server)))?;
|
||||
|
||||
let router = build_router(server.clone());
|
||||
let router = build_router(Arc::clone(&server));
|
||||
|
||||
for addr in config.listen.http.iter().copied() {
|
||||
server::start_plain(tasks, "server", router.clone(), addr)?;
|
||||
@@ -242,11 +248,7 @@ pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
|
||||
tasks,
|
||||
"server",
|
||||
router,
|
||||
unix::net::SocketAddr::from_pathname({
|
||||
let mut path = temp_dir();
|
||||
path.push("authentik.sock");
|
||||
path
|
||||
})?,
|
||||
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
|
||||
)?;
|
||||
|
||||
Ok(server)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::fmt::Write as _;
|
||||
|
||||
use aws_lc_rs::digest;
|
||||
use axum::{
|
||||
Router,
|
||||
@@ -7,7 +9,7 @@ use axum::{
|
||||
header::{CACHE_CONTROL, CONTENT_SECURITY_POLICY, VARY},
|
||||
},
|
||||
middleware::{self, Next},
|
||||
response::{IntoResponse, Response},
|
||||
response::{IntoResponse as _, Response},
|
||||
routing::any,
|
||||
};
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
|
||||
@@ -58,8 +60,10 @@ fn is_storage_token_valid(usage: &str, secret_key: &str, request: &Request) -> b
|
||||
let key_hex_digest = key_digest
|
||||
.as_ref()
|
||||
.iter()
|
||||
.map(|b| format!("{b:02x}"))
|
||||
.collect::<String>();
|
||||
.fold(String::new(), |mut acc, b| {
|
||||
let _ = write!(acc, "{b:02x}");
|
||||
acc
|
||||
});
|
||||
|
||||
let mut validation = Validation::new(token_header.alg);
|
||||
validation.validate_exp = false;
|
||||
@@ -157,12 +161,12 @@ pub(crate) fn build_router() -> Router {
|
||||
let static_fs = ServeDir::new("./web/authentik/").append_index_html_on_directories(false);
|
||||
|
||||
router = router.nest_service("/static/dist/", dist_fs.clone());
|
||||
router = router.nest_service("/static/authentik/", static_fs.clone());
|
||||
router = router.nest_service("/static/authentik/", static_fs);
|
||||
|
||||
router = router.nest_service("/if/flow/{flow_slug}/assets/", dist_fs.clone());
|
||||
router = router.nest_service("/if/admin/assets/", dist_fs.clone());
|
||||
router = router.nest_service("/if/user/assets/", dist_fs.clone());
|
||||
router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs.clone());
|
||||
router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs);
|
||||
|
||||
let default_backend = &config.storage.backend;
|
||||
let media_backend = config
|
||||
@@ -171,19 +175,19 @@ pub(crate) fn build_router() -> Router {
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.backend
|
||||
.unwrap_or(default_backend.clone());
|
||||
.unwrap_or_else(|| default_backend.clone());
|
||||
let reports_backend = config
|
||||
.storage
|
||||
.reports
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.backend
|
||||
.unwrap_or(default_backend.clone());
|
||||
.unwrap_or_else(|| default_backend.clone());
|
||||
|
||||
let default_path = &config.storage.file.path;
|
||||
|
||||
if media_backend == "file" {
|
||||
let mut media_path = config
|
||||
let media_path = config
|
||||
.storage
|
||||
.media
|
||||
.clone()
|
||||
@@ -191,8 +195,8 @@ pub(crate) fn build_router() -> Router {
|
||||
.file
|
||||
.unwrap_or_default()
|
||||
.path
|
||||
.unwrap_or(default_path.clone());
|
||||
media_path.push("media");
|
||||
.unwrap_or_else(|| default_path.clone())
|
||||
.join("media");
|
||||
|
||||
let media_fs = ServeDir::new(media_path).append_index_html_on_directories(false);
|
||||
let media_router =
|
||||
@@ -209,7 +213,7 @@ pub(crate) fn build_router() -> Router {
|
||||
}
|
||||
|
||||
if reports_backend == "file" {
|
||||
let mut reports_path = config
|
||||
let reports_path = config
|
||||
.storage
|
||||
.reports
|
||||
.clone()
|
||||
@@ -217,8 +221,8 @@ pub(crate) fn build_router() -> Router {
|
||||
.file
|
||||
.unwrap_or_default()
|
||||
.path
|
||||
.unwrap_or(default_path.clone());
|
||||
reports_path.push("reports");
|
||||
.unwrap_or_else(|| default_path.clone())
|
||||
.join("reports");
|
||||
|
||||
let reports_fs = ServeDir::new(reports_path).append_index_html_on_directories(false);
|
||||
let reports_router =
|
||||
|
||||
@@ -71,10 +71,14 @@ struct CertResolver {
|
||||
fallback: Arc<CertifiedKey>,
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::missing_trait_methods,
|
||||
reason = "the provided methods are sensible enough"
|
||||
)]
|
||||
impl ResolvesServerCert for CertResolver {
|
||||
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
||||
if client_hello.server_name().is_none() {
|
||||
Some(self.fallback.clone())
|
||||
Some(Arc::clone(&self.fallback))
|
||||
} else if let Some(resolver) = &self.proxy_resolver
|
||||
&& let Some(cert) = resolver.resolve(&client_hello)
|
||||
{
|
||||
@@ -82,7 +86,7 @@ impl ResolvesServerCert for CertResolver {
|
||||
} else if let Some(cert) = self.core_resolver.resolve(&client_hello) {
|
||||
Some(cert)
|
||||
} else {
|
||||
Some(self.fallback.clone())
|
||||
Some(Arc::clone(&self.fallback))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,8 +52,12 @@ pub(crate) struct Tlvs<'a> {
|
||||
buf: &'a [u8],
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::missing_trait_methods,
|
||||
reason = "we don't need to implement the other methods here"
|
||||
)]
|
||||
impl<'a> Iterator for Tlvs<'a> {
|
||||
type Item = Result<Tlv<'a>, Error>;
|
||||
type Item = Result<Tlv<'a>, ProxyProtocolError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.buf.is_empty() {
|
||||
@@ -64,7 +68,7 @@ impl<'a> Iterator for Tlvs<'a> {
|
||||
match self
|
||||
.buf
|
||||
.get(1..3)
|
||||
.map(|s| u16::from_be_bytes(s.try_into().expect("infallible")) as usize)
|
||||
.map(|s| -> usize { u16::from_be_bytes(s.try_into().expect("infallible")).into() })
|
||||
{
|
||||
Some(u) if u + 3 <= self.buf.len() => {
|
||||
let (ret, new) = self.buf.split_at(3 + u);
|
||||
@@ -75,7 +79,7 @@ impl<'a> Iterator for Tlvs<'a> {
|
||||
_ => {
|
||||
// Malformed TLV, cannot continue
|
||||
self.buf = &[];
|
||||
Some(Err(Error::Invalid))
|
||||
Some(Err(ProxyProtocolError::Invalid))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -172,44 +176,58 @@ pub(crate) enum Tlv<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Tlv<'a> {
|
||||
fn decode(kind: u8, data: &'a [u8]) -> Result<Tlv<'a>, Error> {
|
||||
fn decode(kind: u8, data: &'a [u8]) -> Result<Self, ProxyProtocolError> {
|
||||
match kind {
|
||||
0x01 => Ok(Self::Alpn(data.into())),
|
||||
0x02 => Ok(Self::Authority(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x03 => Ok(Self::Crc32c(u32::from_be_bytes(
|
||||
data.try_into().map_err(|_| Error::Invalid)?,
|
||||
data.try_into().map_err(|_| ProxyProtocolError::Invalid)?,
|
||||
))),
|
||||
0x04 => Ok(Self::Noop(data.len())),
|
||||
0x05 => Ok(Self::UniqueId(data.into())),
|
||||
0x20 => Ok(Tlv::Ssl(SslInfo(
|
||||
*data.first().ok_or(Error::Invalid)?,
|
||||
*data.first().ok_or(ProxyProtocolError::Invalid)?,
|
||||
u32::from_be_bytes(
|
||||
data.get(1..5)
|
||||
.ok_or(Error::Invalid)?
|
||||
.ok_or(ProxyProtocolError::Invalid)?
|
||||
.try_into()
|
||||
.map_err(|_| Error::Invalid)?,
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?,
|
||||
),
|
||||
data.get(5..).ok_or(Error::Invalid)?.into(),
|
||||
data.get(5..).ok_or(ProxyProtocolError::Invalid)?.into(),
|
||||
))),
|
||||
0x21 => Ok(Self::SslVersion(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x22 => Ok(Self::SslCn(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x23 => Ok(Self::SslCipher(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x24 => Ok(Self::SslSigAlg(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x25 => Ok(Self::SslKeyAlg(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
0x30 => Ok(Self::Netns(
|
||||
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
|
||||
from_utf8(data)
|
||||
.map_err(|_| ProxyProtocolError::Invalid)?
|
||||
.into(),
|
||||
)),
|
||||
t => Ok(Self::Custom(t, data.into())),
|
||||
}
|
||||
@@ -346,20 +364,21 @@ impl<'a> Header<'a> {
|
||||
/// Attempt to parse a PROXY protocol header from the given buffer
|
||||
///
|
||||
/// Returns the parsed header and the number of bytes consumed from the buffer. If the header
|
||||
/// is incomplete, returns [`Error::BufferTooShort`] so more data can be read from the socket.
|
||||
/// is incomplete, returns [`ProxyProtocolError::BufferTooShort`] so more data can be read from
|
||||
/// the socket.
|
||||
///
|
||||
/// If the header is malformed or unsupported, returns [`Error::Invalid`].
|
||||
/// If the header is malformed or unsupported, returns [`ProxyProtocolError::Invalid`].
|
||||
///
|
||||
/// This function will borrow the buffer for the lifetime of the returned header. If
|
||||
/// you need to keep the header around for longer than the buffer, use
|
||||
/// [`Header::into_owned`].
|
||||
#[instrument(skip_all)]
|
||||
pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), Error> {
|
||||
pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), ProxyProtocolError> {
|
||||
match buf.first() {
|
||||
Some(b'P') => super::v1::decode(buf),
|
||||
Some(b'\r') => super::v2::decode(buf),
|
||||
None => Err(Error::BufferTooShort),
|
||||
_ => Err(Error::Invalid),
|
||||
None => Err(ProxyProtocolError::BufferTooShort),
|
||||
_ => Err(ProxyProtocolError::Invalid),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +454,7 @@ impl<'a> Header<'a> {
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Error)]
|
||||
pub(crate) enum Error {
|
||||
pub(crate) enum ProxyProtocolError {
|
||||
#[error("The buffer is too short to contain a complete PROXY protocol header")]
|
||||
BufferTooShort,
|
||||
#[error("The PROXY protocol header is malformed")]
|
||||
@@ -489,7 +508,7 @@ mod tests {
|
||||
for i in 0..case.len() {
|
||||
assert!(matches!(
|
||||
Header::parse(&case[..i]),
|
||||
Err(Error::BufferTooShort)
|
||||
Err(ProxyProtocolError::BufferTooShort)
|
||||
));
|
||||
}
|
||||
|
||||
@@ -500,7 +519,10 @@ mod tests {
|
||||
#[test]
|
||||
fn test_parse_proxy_header_v1_unterminated() {
|
||||
let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR";
|
||||
assert!(matches!(Header::parse(line), Err(Error::Invalid)));
|
||||
assert!(matches!(
|
||||
Header::parse(line),
|
||||
Err(ProxyProtocolError::Invalid)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -9,10 +9,10 @@ use std::{
|
||||
|
||||
use eyre::{Result, eyre};
|
||||
use pin_project_lite::pin_project;
|
||||
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
|
||||
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf};
|
||||
use tracing::instrument;
|
||||
|
||||
use crate::tokio::proxy_protocol::header::{Error, Header};
|
||||
use crate::tokio::proxy_protocol::header::{Header, ProxyProtocolError};
|
||||
|
||||
pub(crate) mod header;
|
||||
mod utils;
|
||||
@@ -67,8 +67,7 @@ impl<S> Deref for ProxyProtocolStream<S> {
|
||||
}
|
||||
|
||||
impl<S> ProxyProtocolStream<S>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
where S: AsyncRead + Unpin
|
||||
{
|
||||
#[instrument(skip_all)]
|
||||
pub(crate) async fn new(mut stream: S) -> Result<Self, io::Error> {
|
||||
@@ -94,7 +93,7 @@ where
|
||||
header: Some(header),
|
||||
});
|
||||
}
|
||||
Err(Error::BufferTooShort) => {}
|
||||
Err(ProxyProtocolError::BufferTooShort) => {}
|
||||
// Something went wrong parsing the PROXY protocol. We assume that we weren't meant
|
||||
// to parse it, and that this is just a regular stream without the PROXY protocol.
|
||||
Err(_) => {
|
||||
@@ -110,8 +109,7 @@ where
|
||||
}
|
||||
|
||||
impl<S> AsyncRead for ProxyProtocolStream<S>
|
||||
where
|
||||
S: AsyncRead,
|
||||
where S: AsyncRead
|
||||
{
|
||||
#[instrument(skip_all)]
|
||||
fn poll_read(
|
||||
@@ -135,8 +133,7 @@ where
|
||||
}
|
||||
|
||||
impl<S> AsyncBufRead for ProxyProtocolStream<S>
|
||||
where
|
||||
S: AsyncBufRead,
|
||||
where S: AsyncBufRead
|
||||
{
|
||||
#[instrument(skip_all)]
|
||||
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
|
||||
@@ -168,8 +165,7 @@ where
|
||||
}
|
||||
|
||||
impl<S> AsyncWrite for ProxyProtocolStream<S>
|
||||
where
|
||||
S: AsyncWrite,
|
||||
where S: AsyncWrite
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
|
||||
@@ -19,6 +19,7 @@ pub(super) trait AddressFamily: FromStr + Into<IpAddr> {
|
||||
}
|
||||
|
||||
impl AddressFamily for Ipv4Addr {
|
||||
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
|
||||
const BYTES: usize = (Self::BITS / 8) as usize;
|
||||
|
||||
fn from_slice(slice: &[u8]) -> Self {
|
||||
@@ -28,6 +29,7 @@ impl AddressFamily for Ipv4Addr {
|
||||
}
|
||||
|
||||
impl AddressFamily for Ipv6Addr {
|
||||
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
|
||||
const BYTES: usize = (Self::BITS / 8) as usize;
|
||||
|
||||
fn from_slice(slice: &[u8]) -> Self {
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
header::{Address, Error, Header, Protocol},
|
||||
header::{Address, Header, Protocol, ProxyProtocolError},
|
||||
utils::{AddressFamily, read_until},
|
||||
};
|
||||
|
||||
@@ -15,33 +15,36 @@ const UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
|
||||
// All other valid PROXY headers are longer than this
|
||||
const MIN_LENGTH: usize = UNKNOWN.len();
|
||||
|
||||
fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, Error> {
|
||||
fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, ProxyProtocolError> {
|
||||
let Some(address) = read_until(&buf[*pos..], b' ') else {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
};
|
||||
|
||||
let addr = from_utf8(address)
|
||||
.map_err(|_| Error::Invalid)
|
||||
.and_then(|s| A::from_str(s).map_err(|_| Error::Invalid))?;
|
||||
.map_err(|_| ProxyProtocolError::Invalid)
|
||||
.and_then(|s| A::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
|
||||
*pos += address.len() + 1;
|
||||
|
||||
Ok(addr)
|
||||
}
|
||||
|
||||
fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, Error> {
|
||||
fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, ProxyProtocolError> {
|
||||
let Some(port) = read_until(&buf[*pos..], delim) else {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
};
|
||||
|
||||
let p = from_utf8(port)
|
||||
.map_err(|_| Error::Invalid)
|
||||
.and_then(|s| u16::from_str(s).map_err(|_| Error::Invalid))?;
|
||||
.map_err(|_| ProxyProtocolError::Invalid)
|
||||
.and_then(|s| u16::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
|
||||
*pos += port.len() + 1;
|
||||
|
||||
Ok(p)
|
||||
}
|
||||
|
||||
fn parse_addrs<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<Address, Error> {
|
||||
fn parse_addrs<A: AddressFamily>(
|
||||
buf: &[u8],
|
||||
pos: &mut usize,
|
||||
) -> Result<Address, ProxyProtocolError> {
|
||||
let src_addr: A = parse_addr(buf, pos)?;
|
||||
let dst_addr: A = parse_addr(buf, pos)?;
|
||||
let src_port = parse_port(buf, pos, b' ')?;
|
||||
@@ -54,20 +57,20 @@ fn parse_addrs<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<Address,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
|
||||
let mut pos = 0;
|
||||
|
||||
if buf.len() < MIN_LENGTH {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
}
|
||||
if !buf.starts_with(GREETING) {
|
||||
return Err(Error::Invalid);
|
||||
return Err(ProxyProtocolError::Invalid);
|
||||
}
|
||||
pos += GREETING.len() + 1;
|
||||
|
||||
let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
|
||||
let Some(rest) = read_until(&buf[pos..], b'\r') else {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
};
|
||||
pos += rest.len() + 1;
|
||||
|
||||
@@ -79,14 +82,14 @@ fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
match proto {
|
||||
b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
|
||||
b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
|
||||
_ => return Err(Error::Invalid),
|
||||
_ => return Err(ProxyProtocolError::Invalid),
|
||||
}
|
||||
};
|
||||
|
||||
match buf.get(pos) {
|
||||
Some(b'\n') => pos += 1,
|
||||
None => return Err(Error::BufferTooShort),
|
||||
_ => return Err(Error::Invalid),
|
||||
None => return Err(ProxyProtocolError::BufferTooShort),
|
||||
_ => return Err(ProxyProtocolError::Invalid),
|
||||
}
|
||||
|
||||
Ok((Header(addrs, Cow::default()), pos))
|
||||
@@ -95,11 +98,13 @@ fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
/// Decode a version 1 PROXY header from a buffer.
|
||||
///
|
||||
/// Returns the decoded header and the number of bytes consumed from the buffer.
|
||||
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
|
||||
// Guard against a malicious client sending a very long header, since it is a
|
||||
// delimited protocol.
|
||||
match decode_inner(buf) {
|
||||
Err(Error::BufferTooShort) if buf.len() >= MAX_LENGTH => Err(Error::Invalid),
|
||||
Err(ProxyProtocolError::BufferTooShort) if buf.len() >= MAX_LENGTH => {
|
||||
Err(ProxyProtocolError::Invalid)
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
header::{Address, Error, Header, Protocol},
|
||||
header::{Address, Header, Protocol, ProxyProtocolError},
|
||||
utils::AddressFamily,
|
||||
};
|
||||
|
||||
@@ -17,12 +17,12 @@ fn parse_addrs<T: AddressFamily>(
|
||||
pos: &mut usize,
|
||||
rest: &mut usize,
|
||||
protocol: Protocol,
|
||||
) -> Result<Address, Error> {
|
||||
) -> Result<Address, ProxyProtocolError> {
|
||||
if buf.len() < *pos + T::BYTES * 2 + 4 {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
}
|
||||
if *rest < T::BYTES * 2 + 4 {
|
||||
return Err(Error::Invalid);
|
||||
return Err(ProxyProtocolError::Invalid);
|
||||
}
|
||||
|
||||
let addr = Address {
|
||||
@@ -46,28 +46,28 @@ fn parse_addrs<T: AddressFamily>(
|
||||
/// Decode a version 2 PROXY header from a buffer.
|
||||
///
|
||||
/// Returns the decoded header and the number of bytes consumed from the buffer.
|
||||
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
|
||||
let mut pos = 0;
|
||||
|
||||
if buf.len() < MIN_LENGTH {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
}
|
||||
if !buf.starts_with(GREETING) {
|
||||
return Err(Error::Invalid);
|
||||
return Err(ProxyProtocolError::Invalid);
|
||||
}
|
||||
pos += GREETING.len();
|
||||
|
||||
let is_local = match buf[pos] {
|
||||
0x20 => true,
|
||||
0x21 => false,
|
||||
_ => return Err(Error::Invalid),
|
||||
_ => return Err(ProxyProtocolError::Invalid),
|
||||
};
|
||||
let protocol = buf[pos + 1];
|
||||
let mut rest: usize = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]).into();
|
||||
pos += 4;
|
||||
|
||||
if buf.len() < pos + rest {
|
||||
return Err(Error::BufferTooShort);
|
||||
return Err(ProxyProtocolError::BufferTooShort);
|
||||
}
|
||||
|
||||
let addr_info = match protocol {
|
||||
@@ -99,13 +99,13 @@ pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
|
||||
0x31 | 0x32 => {
|
||||
// AF_UNIX - we don't parse it, but don't reject it either in case we need the TLVs
|
||||
if rest < AF_UNIX_ADDRS_LEN {
|
||||
return Err(Error::Invalid);
|
||||
return Err(ProxyProtocolError::Invalid);
|
||||
}
|
||||
rest -= AF_UNIX_ADDRS_LEN;
|
||||
pos += AF_UNIX_ADDRS_LEN;
|
||||
None
|
||||
}
|
||||
_ => return Err(Error::Invalid),
|
||||
_ => return Err(ProxyProtocolError::Invalid),
|
||||
};
|
||||
|
||||
let tlv_data = Cow::Borrowed(&buf[pos..pos + rest]);
|
||||
|
||||
@@ -65,9 +65,7 @@ mod json {
|
||||
use tracing_subscriber::{layer::Layer, registry::LookupSpan};
|
||||
|
||||
pub(super) fn layer<S>() -> impl Layer<S>
|
||||
where
|
||||
S: Subscriber + for<'lookup> LookupSpan<'lookup>,
|
||||
{
|
||||
where S: Subscriber + for<'lookup> LookupSpan<'lookup> {
|
||||
let mut json_layer = json_subscriber::fmt::layer()
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
|
||||
Reference in New Issue
Block a user