Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
Marc 'risson' Schmitt
2026-03-19 19:20:34 +01:00
parent f3341a4b83
commit 04c066d8b0
26 changed files with 177 additions and 133 deletions

View File

@@ -150,6 +150,8 @@ perf = { priority = -1, level = "warn" }
style = { priority = -1, level = "warn" }
suspicious = { priority = -1, level = "warn" }
### and disable the ones we don't want
### cargo group
multiple_crate_versions = "allow"
### pedantic group
redundant_closure_for_method_calls = "allow"
struct_field_names = "allow"
@@ -170,7 +172,6 @@ create_dir = "warn"
dbg_macro = "warn"
default_numeric_fallback = "warn"
disallowed_script_idents = "warn"
doc_paragraphs_missing_punctuation = "warn"
empty_drop = "warn"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"

View File

@@ -4,7 +4,7 @@ use axum::{Extension, middleware::AddExtension};
use axum_server::accept::{Accept, DefaultAcceptor};
use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tower::Layer;
use tower::Layer as _;
use tracing::instrument;
use crate::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};

View File

@@ -3,7 +3,7 @@ use axum_server::{accept::Accept, tls_rustls::RustlsAcceptor};
use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{rustls::pki_types::CertificateDer, server::TlsStream};
use tower::Layer;
use tower::Layer as _;
use tracing::instrument;
#[derive(Clone, Debug)]

View File

@@ -1,7 +1,7 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use axum::{
Extension, RequestPartsExt,
Extension, RequestPartsExt as _,
extract::{ConnectInfo, FromRequestParts, Request},
http::request::Parts,
middleware::Next,
@@ -81,7 +81,7 @@ pub(crate) async fn client_ip_middleware(request: Request, next: Next) -> Respon
#[cfg(test)]
mod tests {
use std::net::{Ipv4Addr, Ipv6Addr};
use std::net::Ipv4Addr;
use axum::{body::Body, http::Request};

View File

@@ -1,5 +1,5 @@
use axum::{
Extension, RequestPartsExt,
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{
header::{FORWARDED, HOST},
@@ -7,7 +7,7 @@ use axum::{
status::StatusCode,
},
middleware::Next,
response::{IntoResponse, Response},
response::{IntoResponse as _, Response},
};
use forwarded_header_value::ForwardedHeaderValue;
use tracing::{Span, instrument};

View File

@@ -1,5 +1,5 @@
use axum::{
Extension, RequestPartsExt,
Extension, RequestPartsExt as _,
extract::{FromRequestParts, Request},
http::{self, header::FORWARDED, request::Parts},
middleware::Next,

View File

@@ -1,7 +1,7 @@
use std::net::SocketAddr;
use axum::{
Extension, RequestPartsExt,
Extension, RequestPartsExt as _,
extract::{ConnectInfo, FromRequestParts, Request},
http::request::Parts,
middleware::Next,

View File

@@ -2,7 +2,7 @@ use std::collections::HashMap;
use axum::{extract::Request, middleware::Next, response::Response};
use tokio::time::Instant;
use tracing::{Instrument, field, info, info_span, trace};
use tracing::{Instrument as _, field, info, info_span, trace};
use crate::config;

View File

@@ -7,7 +7,7 @@ use eyre::{Report, Result};
use rustls::{
RootCertStore,
crypto::CryptoProvider,
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject as _},
server::ClientHello,
sign::CertifiedKey,
};

View File

@@ -7,7 +7,7 @@ use std::{
use arc_swap::{ArcSwap, Guard};
use eyre::Result;
use notify::{RecommendedWatcher, Watcher};
use notify::{RecommendedWatcher, Watcher as _};
use serde_json::{Map, Value};
use tokio::sync::mpsc;
use tracing::{error, info, warn};
@@ -86,7 +86,7 @@ impl Config {
"file" => {
let path = uri.path();
match read_to_string(path).map(|s| s.trim().to_owned()) {
Ok(value) => return (value.to_owned(), Some(PathBuf::from(path))),
Ok(value) => return (value, Some(PathBuf::from(path))),
Err(err) => {
error!("failed to read config value from {path}: {err}");
return (fallback, Some(PathBuf::from(path)));
@@ -96,7 +96,7 @@ impl Config {
"env" => {
if let Some(var) = uri.host_str() {
if let Ok(value) = env::var(var) {
return (value.to_owned(), None);
return (value, None);
}
return (fallback, None);
}
@@ -141,10 +141,10 @@ impl Config {
(value, file_paths)
}
fn load(config_paths: &[PathBuf]) -> Result<(Config, Vec<PathBuf>)> {
fn load(config_paths: &[PathBuf]) -> Result<(Self, Vec<PathBuf>)> {
let raw = Self::load_raw(config_paths)?;
let (expanded, file_paths) = Self::expand(raw);
let config: Config = serde_json::from_value(expanded)?;
let config: Self = serde_json::from_value(expanded)?;
Ok((config, file_paths))
}
}

View File

@@ -1,8 +1,8 @@
use std::{str::FromStr, sync::OnceLock, time::Duration};
use std::{str::FromStr as _, sync::OnceLock, time::Duration};
use eyre::Result;
use sqlx::{
Executor, PgPool,
Executor as _, PgPool,
postgres::{PgConnectOptions, PgPoolOptions, PgSslMode},
};
use tracing::{info, log::LevelFilter, trace};

View File

@@ -1,7 +1,9 @@
use std::sync::Arc;
use std::{
process::exit,
sync::atomic::{AtomicUsize, Ordering},
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use ::tracing::{error, info, trace};

View File

@@ -50,7 +50,7 @@ pub(super) async fn metrics_handler(State(state): State<Arc<Metrics>>) -> Result
mod python {
use eyre::{Report, Result};
use pyo3::{
IntoPyObjectExt,
IntoPyObjectExt as _,
ffi::c_str,
prelude::*,
types::{PyBytes, PyDict},
@@ -81,7 +81,7 @@ output = generate_latest(registry)
.get_item("output")?
.unwrap_or(PyBytes::new(py, &[]).into_bound_py_any(py)?)
.cast::<PyBytes>()
.unwrap_or(&PyBytes::new(py, &[]))
.map_or_else(|_| PyBytes::new(py, &[]), |v| v.to_owned())
.as_bytes()
.to_owned();
Ok::<_, Report>(metrics)

View File

@@ -42,7 +42,7 @@ async fn run_upkeep(arbiter: Arbiter, state: Arc<Metrics>) -> Result<()> {
loop {
tokio::select! {
() = tokio::time::sleep(Duration::from_secs(5)) => {
let state_clone = state.clone();
let state_clone = Arc::clone(&state);
tokio::task::spawn_blocking(move || state_clone.prometheus.run_upkeep()).await?;
},
() = arbiter.shutdown() => return Ok(())
@@ -62,12 +62,12 @@ fn build_router(state: Arc<Metrics>) -> Router {
pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
let arbiter = tasks.arbiter();
let metrics = Arc::new(Metrics::new()?);
let router = build_router(metrics.clone());
let router = build_router(Arc::clone(&metrics));
tasks
.build_task()
.name(&format!("{}::run_upkeep", module_path!(),))
.spawn(run_upkeep(arbiter.clone(), metrics.clone()))?;
.spawn(run_upkeep(arbiter, Arc::clone(&metrics)))?;
for addr in config::get().listen.metrics.iter().copied() {
server::start_plain(tasks, "metrics", router.clone(), addr)?;
@@ -77,11 +77,7 @@ pub(super) fn run(tasks: &mut Tasks) -> Result<Arc<Metrics>> {
tasks,
"metrics",
router,
unix::net::SocketAddr::from_pathname({
let mut path = temp_dir();
path.push("authentik-metrics.sock");
path
})?,
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik-metrics.sock"))?,
)?;
Ok(metrics)

View File

@@ -40,6 +40,13 @@ impl std::fmt::Display for Mode {
}
}
impl From<Mode> for u8 {
#[expect(clippy::as_conversions, reason = "repr of enum is u8")]
fn from(value: Mode) -> Self {
value as Self
}
}
impl Mode {
pub(crate) fn get() -> Self {
match MODE.load(Ordering::Relaxed) {
@@ -57,7 +64,7 @@ impl Mode {
pub(crate) fn set(mode: Self) -> Result<()> {
std::fs::write(mode_path(), mode.to_string())?;
MODE.store(mode as u8, Ordering::SeqCst);
MODE.store(mode.into(), Ordering::SeqCst);
Ok(())
}

View File

@@ -7,6 +7,10 @@ use crate::arbiter::{Arbiter, Tasks};
#[derive(Debug, FromArgs, PartialEq)]
/// Run the authentik proxy outpost.
#[argh(subcommand, name = "proxy")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(crate) struct Cli {}
pub(crate) mod tls {
@@ -15,9 +19,10 @@ pub(crate) mod tls {
use rustls::{server::ClientHello, sign::CertifiedKey};
#[derive(Debug)]
pub(crate) struct CertResolver {}
pub(crate) struct CertResolver;
impl CertResolver {
#[expect(clippy::unused_self, reason = "still WIP")]
pub(crate) fn resolve(&self, _client_hello: &ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
None
}

View File

@@ -12,7 +12,7 @@ use axum::{
response::{IntoResponse, Response},
routing::any,
};
use http_body_util::BodyExt;
use http_body_util::BodyExt as _;
use serde_json::json;
use crate::{
@@ -226,7 +226,7 @@ pub(super) fn build_router(server: Arc<Server>) -> Router {
Router::new()
.route("/-/metrics/", any((StatusCode::NOT_FOUND, "not found")))
.route("/-/health/ready/", any(health_ready))
.with_state(server.clone())
.with_state(Arc::clone(&server))
.merge(super::r#static::build_router()),
true,
)
@@ -264,9 +264,9 @@ mod websockets {
CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, UPGRADE,
},
},
response::{IntoResponse, Response},
response::{IntoResponse as _, Response},
};
use futures::{SinkExt, StreamExt};
use futures::{SinkExt as _, StreamExt as _};
use hyper_util::rt::TokioIo;
use tokio::{net::UnixStream, sync::mpsc};
use tokio_tungstenite::{
@@ -373,7 +373,7 @@ mod websockets {
Message::Close(_) => {
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
close_tx.send(()).await.ok();
let _ = close_tx.send(()).await;
client_closed = true;
break;
}
@@ -391,7 +391,7 @@ mod websockets {
}
if !client_closed {
upstream_sender.send(Message::Close(None)).await?;
close_tx.send(()).await.ok();
let _ = close_tx.send(()).await;
}
Ok::<_, AppError>(())
});
@@ -404,7 +404,7 @@ mod websockets {
Message::Close(_) => {
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
close_tx_upstream.send(()).await.ok();
let _ = close_tx_upstream.send(()).await;
upstream_closed = true;
break;
}
@@ -422,7 +422,7 @@ mod websockets {
}
if !upstream_closed {
client_sender.send(Message::Close(None)).await?;
close_tx_upstream.send(()).await.ok();
let _ = close_tx_upstream.send(()).await;
}
Ok::<_, AppError>(())
});

View File

@@ -27,7 +27,7 @@ use tokio::{
sync::{Mutex, broadcast::error::RecvError},
time::Instant,
};
use tower::ServiceExt;
use tower::ServiceExt as _;
use tower_http::timeout::TimeoutLayer;
use tracing::{info, trace, warn};
@@ -47,6 +47,10 @@ mod tls;
#[derive(Debug, Default, FromArgs, PartialEq)]
/// Run the authentik server.
#[argh(subcommand, name = "server")]
#[expect(
clippy::empty_structs_with_brackets,
reason = "argh doesn't support unit structs"
)]
pub(super) struct Cli {}
pub(crate) struct Server {
@@ -90,10 +94,12 @@ impl Server {
signal = signal.as_str(),
"sending shutdown signal to gunicorn"
);
if let Some(id) = self.gunicorn.lock().await.id() {
let mut gunicorn = self.gunicorn.lock().await;
if let Some(id) = gunicorn.id() {
kill(Pid::from_raw(id.cast_signed()), signal)?;
}
self.gunicorn.lock().await.wait().await?;
gunicorn.wait().await?;
drop(gunicorn);
Ok(())
}
@@ -221,9 +227,9 @@ pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
tasks
.build_task()
.name(&format!("{}::watch_server", module_path!()))
.spawn(watch_server(arbiter.clone(), server.clone()))?;
.spawn(watch_server(arbiter.clone(), Arc::clone(&server)))?;
let router = build_router(server.clone());
let router = build_router(Arc::clone(&server));
for addr in config.listen.http.iter().copied() {
server::start_plain(tasks, "server", router.clone(), addr)?;
@@ -242,11 +248,7 @@ pub(super) fn run(_cli: Cli, tasks: &mut Tasks) -> Result<Arc<Server>> {
tasks,
"server",
router,
unix::net::SocketAddr::from_pathname({
let mut path = temp_dir();
path.push("authentik.sock");
path
})?,
unix::net::SocketAddr::from_pathname(temp_dir().join("authentik.sock"))?,
)?;
Ok(server)

View File

@@ -1,3 +1,5 @@
use std::fmt::Write as _;
use aws_lc_rs::digest;
use axum::{
Router,
@@ -7,7 +9,7 @@ use axum::{
header::{CACHE_CONTROL, CONTENT_SECURITY_POLICY, VARY},
},
middleware::{self, Next},
response::{IntoResponse, Response},
response::{IntoResponse as _, Response},
routing::any,
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header};
@@ -58,8 +60,10 @@ fn is_storage_token_valid(usage: &str, secret_key: &str, request: &Request) -> b
let key_hex_digest = key_digest
.as_ref()
.iter()
.map(|b| format!("{b:02x}"))
.collect::<String>();
.fold(String::new(), |mut acc, b| {
let _ = write!(acc, "{b:02x}");
acc
});
let mut validation = Validation::new(token_header.alg);
validation.validate_exp = false;
@@ -157,12 +161,12 @@ pub(crate) fn build_router() -> Router {
let static_fs = ServeDir::new("./web/authentik/").append_index_html_on_directories(false);
router = router.nest_service("/static/dist/", dist_fs.clone());
router = router.nest_service("/static/authentik/", static_fs.clone());
router = router.nest_service("/static/authentik/", static_fs);
router = router.nest_service("/if/flow/{flow_slug}/assets/", dist_fs.clone());
router = router.nest_service("/if/admin/assets/", dist_fs.clone());
router = router.nest_service("/if/user/assets/", dist_fs.clone());
router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs.clone());
router = router.nest_service("/if/rac/{app_slug}/assets/", dist_fs);
let default_backend = &config.storage.backend;
let media_backend = config
@@ -171,19 +175,19 @@ pub(crate) fn build_router() -> Router {
.clone()
.unwrap_or_default()
.backend
.unwrap_or(default_backend.clone());
.unwrap_or_else(|| default_backend.clone());
let reports_backend = config
.storage
.reports
.clone()
.unwrap_or_default()
.backend
.unwrap_or(default_backend.clone());
.unwrap_or_else(|| default_backend.clone());
let default_path = &config.storage.file.path;
if media_backend == "file" {
let mut media_path = config
let media_path = config
.storage
.media
.clone()
@@ -191,8 +195,8 @@ pub(crate) fn build_router() -> Router {
.file
.unwrap_or_default()
.path
.unwrap_or(default_path.clone());
media_path.push("media");
.unwrap_or_else(|| default_path.clone())
.join("media");
let media_fs = ServeDir::new(media_path).append_index_html_on_directories(false);
let media_router =
@@ -209,7 +213,7 @@ pub(crate) fn build_router() -> Router {
}
if reports_backend == "file" {
let mut reports_path = config
let reports_path = config
.storage
.reports
.clone()
@@ -217,8 +221,8 @@ pub(crate) fn build_router() -> Router {
.file
.unwrap_or_default()
.path
.unwrap_or(default_path.clone());
reports_path.push("reports");
.unwrap_or_else(|| default_path.clone())
.join("reports");
let reports_fs = ServeDir::new(reports_path).append_index_html_on_directories(false);
let reports_router =

View File

@@ -71,10 +71,14 @@ struct CertResolver {
fallback: Arc<CertifiedKey>,
}
#[expect(
clippy::missing_trait_methods,
reason = "the provided methods are sensible enough"
)]
impl ResolvesServerCert for CertResolver {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if client_hello.server_name().is_none() {
Some(self.fallback.clone())
Some(Arc::clone(&self.fallback))
} else if let Some(resolver) = &self.proxy_resolver
&& let Some(cert) = resolver.resolve(&client_hello)
{
@@ -82,7 +86,7 @@ impl ResolvesServerCert for CertResolver {
} else if let Some(cert) = self.core_resolver.resolve(&client_hello) {
Some(cert)
} else {
Some(self.fallback.clone())
Some(Arc::clone(&self.fallback))
}
}
}

View File

@@ -52,8 +52,12 @@ pub(crate) struct Tlvs<'a> {
buf: &'a [u8],
}
#[expect(
clippy::missing_trait_methods,
reason = "we don't need to implement the other methods here"
)]
impl<'a> Iterator for Tlvs<'a> {
type Item = Result<Tlv<'a>, Error>;
type Item = Result<Tlv<'a>, ProxyProtocolError>;
fn next(&mut self) -> Option<Self::Item> {
if self.buf.is_empty() {
@@ -64,7 +68,7 @@ impl<'a> Iterator for Tlvs<'a> {
match self
.buf
.get(1..3)
.map(|s| u16::from_be_bytes(s.try_into().expect("infallible")) as usize)
.map(|s| -> usize { u16::from_be_bytes(s.try_into().expect("infallible")).into() })
{
Some(u) if u + 3 <= self.buf.len() => {
let (ret, new) = self.buf.split_at(3 + u);
@@ -75,7 +79,7 @@ impl<'a> Iterator for Tlvs<'a> {
_ => {
// Malformed TLV, cannot continue
self.buf = &[];
Some(Err(Error::Invalid))
Some(Err(ProxyProtocolError::Invalid))
}
}
}
@@ -172,44 +176,58 @@ pub(crate) enum Tlv<'a> {
}
impl<'a> Tlv<'a> {
fn decode(kind: u8, data: &'a [u8]) -> Result<Tlv<'a>, Error> {
fn decode(kind: u8, data: &'a [u8]) -> Result<Self, ProxyProtocolError> {
match kind {
0x01 => Ok(Self::Alpn(data.into())),
0x02 => Ok(Self::Authority(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x03 => Ok(Self::Crc32c(u32::from_be_bytes(
data.try_into().map_err(|_| Error::Invalid)?,
data.try_into().map_err(|_| ProxyProtocolError::Invalid)?,
))),
0x04 => Ok(Self::Noop(data.len())),
0x05 => Ok(Self::UniqueId(data.into())),
0x20 => Ok(Tlv::Ssl(SslInfo(
*data.first().ok_or(Error::Invalid)?,
*data.first().ok_or(ProxyProtocolError::Invalid)?,
u32::from_be_bytes(
data.get(1..5)
.ok_or(Error::Invalid)?
.ok_or(ProxyProtocolError::Invalid)?
.try_into()
.map_err(|_| Error::Invalid)?,
.map_err(|_| ProxyProtocolError::Invalid)?,
),
data.get(5..).ok_or(Error::Invalid)?.into(),
data.get(5..).ok_or(ProxyProtocolError::Invalid)?.into(),
))),
0x21 => Ok(Self::SslVersion(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x22 => Ok(Self::SslCn(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x23 => Ok(Self::SslCipher(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x24 => Ok(Self::SslSigAlg(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x25 => Ok(Self::SslKeyAlg(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
0x30 => Ok(Self::Netns(
from_utf8(data).map_err(|_| Error::Invalid)?.into(),
from_utf8(data)
.map_err(|_| ProxyProtocolError::Invalid)?
.into(),
)),
t => Ok(Self::Custom(t, data.into())),
}
@@ -346,20 +364,21 @@ impl<'a> Header<'a> {
/// Attempt to parse a PROXY protocol header from the given buffer
///
/// Returns the parsed header and the number of bytes consumed from the buffer. If the header
/// is incomplete, returns [`Error::BufferTooShort`] so more data can be read from the socket.
/// is incomplete, returns [`ProxyProtocolError::BufferTooShort`] so more data can be read from
/// the socket.
///
/// If the header is malformed or unsupported, returns [`Error::Invalid`].
/// If the header is malformed or unsupported, returns [`ProxyProtocolError::Invalid`].
///
/// This function will borrow the buffer for the lifetime of the returned header. If
/// you need to keep the header around for longer than the buffer, use
/// [`Header::into_owned`].
#[instrument(skip_all)]
pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), Error> {
pub(super) fn parse(buf: &'a [u8]) -> Result<(Self, usize), ProxyProtocolError> {
match buf.first() {
Some(b'P') => super::v1::decode(buf),
Some(b'\r') => super::v2::decode(buf),
None => Err(Error::BufferTooShort),
_ => Err(Error::Invalid),
None => Err(ProxyProtocolError::BufferTooShort),
_ => Err(ProxyProtocolError::Invalid),
}
}
@@ -435,7 +454,7 @@ impl<'a> Header<'a> {
}
#[derive(Debug, PartialEq, Eq, Error)]
pub(crate) enum Error {
pub(crate) enum ProxyProtocolError {
#[error("The buffer is too short to contain a complete PROXY protocol header")]
BufferTooShort,
#[error("The PROXY protocol header is malformed")]
@@ -489,7 +508,7 @@ mod tests {
for i in 0..case.len() {
assert!(matches!(
Header::parse(&case[..i]),
Err(Error::BufferTooShort)
Err(ProxyProtocolError::BufferTooShort)
));
}
@@ -500,7 +519,10 @@ mod tests {
#[test]
fn test_parse_proxy_header_v1_unterminated() {
let line = b"PROXY TCP4 THISISSTORYALLABOUTHOWMYLIFEGOTFLIPPEDTURNEDUPSIDEDOWNANDIDLIKETOTAKEAMINUTEJUSTSITRIGHTTHEREANDILLTELLYOUHOWIGOTTHEPRINCEOFAIR";
assert!(matches!(Header::parse(line), Err(Error::Invalid)));
assert!(matches!(
Header::parse(line),
Err(ProxyProtocolError::Invalid)
));
}
#[test]

View File

@@ -9,10 +9,10 @@ use std::{
use eyre::{Result, eyre};
use pin_project_lite::pin_project;
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::io::{AsyncBufRead, AsyncRead, AsyncReadExt as _, AsyncWrite, ReadBuf};
use tracing::instrument;
use crate::tokio::proxy_protocol::header::{Error, Header};
use crate::tokio::proxy_protocol::header::{Header, ProxyProtocolError};
pub(crate) mod header;
mod utils;
@@ -67,8 +67,7 @@ impl<S> Deref for ProxyProtocolStream<S> {
}
impl<S> ProxyProtocolStream<S>
where
S: AsyncRead + Unpin,
where S: AsyncRead + Unpin
{
#[instrument(skip_all)]
pub(crate) async fn new(mut stream: S) -> Result<Self, io::Error> {
@@ -94,7 +93,7 @@ where
header: Some(header),
});
}
Err(Error::BufferTooShort) => {}
Err(ProxyProtocolError::BufferTooShort) => {}
// Something went wrong parsing the PROXY protocol. We assume that we weren't meant
// to parse it, and that this is just a regular stream without the PROXY protocol.
Err(_) => {
@@ -110,8 +109,7 @@ where
}
impl<S> AsyncRead for ProxyProtocolStream<S>
where
S: AsyncRead,
where S: AsyncRead
{
#[instrument(skip_all)]
fn poll_read(
@@ -135,8 +133,7 @@ where
}
impl<S> AsyncBufRead for ProxyProtocolStream<S>
where
S: AsyncBufRead,
where S: AsyncBufRead
{
#[instrument(skip_all)]
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
@@ -168,8 +165,7 @@ where
}
impl<S> AsyncWrite for ProxyProtocolStream<S>
where
S: AsyncWrite,
where S: AsyncWrite
{
fn poll_write(
self: Pin<&mut Self>,

View File

@@ -19,6 +19,7 @@ pub(super) trait AddressFamily: FromStr + Into<IpAddr> {
}
impl AddressFamily for Ipv4Addr {
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
const BYTES: usize = (Self::BITS / 8) as usize;
fn from_slice(slice: &[u8]) -> Self {
@@ -28,6 +29,7 @@ impl AddressFamily for Ipv4Addr {
}
impl AddressFamily for Ipv6Addr {
#[expect(clippy::as_conversions, reason = "will always be in bounds")]
const BYTES: usize = (Self::BITS / 8) as usize;
fn from_slice(slice: &[u8]) -> Self {

View File

@@ -5,7 +5,7 @@ use std::{
};
use super::{
header::{Address, Error, Header, Protocol},
header::{Address, Header, Protocol, ProxyProtocolError},
utils::{AddressFamily, read_until},
};
@@ -15,33 +15,36 @@ const UNKNOWN: &[u8] = b"PROXY UNKNOWN\r\n";
// All other valid PROXY headers are longer than this
const MIN_LENGTH: usize = UNKNOWN.len();
fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, Error> {
fn parse_addr<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<A, ProxyProtocolError> {
let Some(address) = read_until(&buf[*pos..], b' ') else {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
};
let addr = from_utf8(address)
.map_err(|_| Error::Invalid)
.and_then(|s| A::from_str(s).map_err(|_| Error::Invalid))?;
.map_err(|_| ProxyProtocolError::Invalid)
.and_then(|s| A::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
*pos += address.len() + 1;
Ok(addr)
}
fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, Error> {
fn parse_port(buf: &[u8], pos: &mut usize, delim: u8) -> Result<u16, ProxyProtocolError> {
let Some(port) = read_until(&buf[*pos..], delim) else {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
};
let p = from_utf8(port)
.map_err(|_| Error::Invalid)
.and_then(|s| u16::from_str(s).map_err(|_| Error::Invalid))?;
.map_err(|_| ProxyProtocolError::Invalid)
.and_then(|s| u16::from_str(s).map_err(|_| ProxyProtocolError::Invalid))?;
*pos += port.len() + 1;
Ok(p)
}
fn parse_addrs<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<Address, Error> {
fn parse_addrs<A: AddressFamily>(
buf: &[u8],
pos: &mut usize,
) -> Result<Address, ProxyProtocolError> {
let src_addr: A = parse_addr(buf, pos)?;
let dst_addr: A = parse_addr(buf, pos)?;
let src_port = parse_port(buf, pos, b' ')?;
@@ -54,20 +57,20 @@ fn parse_addrs<A: AddressFamily>(buf: &[u8], pos: &mut usize) -> Result<Address,
})
}
fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
let mut pos = 0;
if buf.len() < MIN_LENGTH {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
}
if !buf.starts_with(GREETING) {
return Err(Error::Invalid);
return Err(ProxyProtocolError::Invalid);
}
pos += GREETING.len() + 1;
let addrs = if buf[pos..].starts_with(b"UNKNOWN") {
let Some(rest) = read_until(&buf[pos..], b'\r') else {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
};
pos += rest.len() + 1;
@@ -79,14 +82,14 @@ fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
match proto {
b"TCP4 " => Some(parse_addrs::<Ipv4Addr>(buf, &mut pos)?),
b"TCP6 " => Some(parse_addrs::<Ipv6Addr>(buf, &mut pos)?),
_ => return Err(Error::Invalid),
_ => return Err(ProxyProtocolError::Invalid),
}
};
match buf.get(pos) {
Some(b'\n') => pos += 1,
None => return Err(Error::BufferTooShort),
_ => return Err(Error::Invalid),
None => return Err(ProxyProtocolError::BufferTooShort),
_ => return Err(ProxyProtocolError::Invalid),
}
Ok((Header(addrs, Cow::default()), pos))
@@ -95,11 +98,13 @@ fn decode_inner(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
/// Decode a version 1 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
// Guard against a malicious client sending a very long header, since it is a
// delimited protocol.
match decode_inner(buf) {
Err(Error::BufferTooShort) if buf.len() >= MAX_LENGTH => Err(Error::Invalid),
Err(ProxyProtocolError::BufferTooShort) if buf.len() >= MAX_LENGTH => {
Err(ProxyProtocolError::Invalid)
}
other => other,
}
}

View File

@@ -4,7 +4,7 @@ use std::{
};
use super::{
header::{Address, Error, Header, Protocol},
header::{Address, Header, Protocol, ProxyProtocolError},
utils::AddressFamily,
};
@@ -17,12 +17,12 @@ fn parse_addrs<T: AddressFamily>(
pos: &mut usize,
rest: &mut usize,
protocol: Protocol,
) -> Result<Address, Error> {
) -> Result<Address, ProxyProtocolError> {
if buf.len() < *pos + T::BYTES * 2 + 4 {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
}
if *rest < T::BYTES * 2 + 4 {
return Err(Error::Invalid);
return Err(ProxyProtocolError::Invalid);
}
let addr = Address {
@@ -46,28 +46,28 @@ fn parse_addrs<T: AddressFamily>(
/// Decode a version 2 PROXY header from a buffer.
///
/// Returns the decoded header and the number of bytes consumed from the buffer.
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), ProxyProtocolError> {
let mut pos = 0;
if buf.len() < MIN_LENGTH {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
}
if !buf.starts_with(GREETING) {
return Err(Error::Invalid);
return Err(ProxyProtocolError::Invalid);
}
pos += GREETING.len();
let is_local = match buf[pos] {
0x20 => true,
0x21 => false,
_ => return Err(Error::Invalid),
_ => return Err(ProxyProtocolError::Invalid),
};
let protocol = buf[pos + 1];
let mut rest: usize = u16::from_be_bytes([buf[pos + 2], buf[pos + 3]]).into();
pos += 4;
if buf.len() < pos + rest {
return Err(Error::BufferTooShort);
return Err(ProxyProtocolError::BufferTooShort);
}
let addr_info = match protocol {
@@ -99,13 +99,13 @@ pub(super) fn decode(buf: &[u8]) -> Result<(Header<'_>, usize), Error> {
0x31 | 0x32 => {
// AF_UNIX - we don't parse it, but don't reject it either in case we need the TLVs
if rest < AF_UNIX_ADDRS_LEN {
return Err(Error::Invalid);
return Err(ProxyProtocolError::Invalid);
}
rest -= AF_UNIX_ADDRS_LEN;
pos += AF_UNIX_ADDRS_LEN;
None
}
_ => return Err(Error::Invalid),
_ => return Err(ProxyProtocolError::Invalid),
};
let tlv_data = Cow::Borrowed(&buf[pos..pos + rest]);

View File

@@ -65,9 +65,7 @@ mod json {
use tracing_subscriber::{layer::Layer, registry::LookupSpan};
pub(super) fn layer<S>() -> impl Layer<S>
where
S: Subscriber + for<'lookup> LookupSpan<'lookup>,
{
where S: Subscriber + for<'lookup> LookupSpan<'lookup> {
let mut json_layer = json_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true)