diff --git a/Cargo.lock b/Cargo.lock index 5c97db4ee8..a712be9697 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -133,6 +133,7 @@ dependencies = [ "client-ip", "durstr", "eyre", + "forwarded-header-value", "futures", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index 4a752f49ce..97ab79628d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ console-subscriber = "= 0.5.0" dotenvy = "= 0.15.7" durstr = "= 0.5.1" eyre = "= 0.6.12" +forwarded-header-value = "= 0.1.1" futures = "= 0.3.32" glob = "= 0.3.3" ipnet = { version = "= 2.12.0", features = ["serde"] } diff --git a/packages/ak-axum/Cargo.toml b/packages/ak-axum/Cargo.toml index 3eee20cc97..67f5df6c7c 100644 --- a/packages/ak-axum/Cargo.toml +++ b/packages/ak-axum/Cargo.toml @@ -16,6 +16,7 @@ axum.workspace = true client-ip.workspace = true durstr.workspace = true eyre.workspace = true +forwarded-header-value.workspace = true futures.workspace = true tokio-rustls.workspace = true tokio.workspace = true diff --git a/packages/ak-axum/src/extract/mod.rs b/packages/ak-axum/src/extract/mod.rs index 84430ddd6a..3837cec844 100644 --- a/packages/ak-axum/src/extract/mod.rs +++ b/packages/ak-axum/src/extract/mod.rs @@ -1,4 +1,5 @@ //! axum extractors to get information about a request. pub mod client_ip; +pub mod scheme; pub mod trusted_proxy; diff --git a/packages/ak-axum/src/extract/scheme.rs b/packages/ak-axum/src/extract/scheme.rs new file mode 100644 index 0000000000..4f82fc7a23 --- /dev/null +++ b/packages/ak-axum/src/extract/scheme.rs @@ -0,0 +1,252 @@ +//! axum extractor and middleware to get the request scheme. + +use axum::{ + Extension, RequestPartsExt as _, + extract::{FromRequestParts, Request}, + http::{self, header::FORWARDED, request::Parts}, + middleware::Next, + response::Response, +}; +use forwarded_header_value::{ForwardedHeaderValue, Protocol}; +use tracing::{Span, instrument}; + +use crate::{ + accept::{proxy_protocol::ProxyProtocolState, tls::TlsState}, + extract::trusted_proxy::TrustedProxy, +}; + +const X_FORWARDED_PROTO: &str = "X-Forwarded-Proto"; +const X_FORWARDED_SCHEME: &str = "X-Forwarded-Scheme"; + +/// Request scheme. +/// +/// The [`scheme_middleware`] must be added to the router before using this extractor, +/// otherwise this will result in requests erroring. +#[derive(Clone, Debug)] +pub struct Scheme(pub http::uri::Scheme); + +impl FromRequestParts for Scheme +where + S: Send + Sync, +{ + type Rejection = as FromRequestParts>::Rejection; + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + Extension::::from_request_parts(parts, state) + .await + .map(|Extension(scheme)| scheme) + } +} + +/// Get the scheme from the request. +#[instrument(skip_all)] +async fn extract_scheme(parts: &mut Parts) -> http::uri::Scheme { + let is_trusted = parts + .extract::() + .await + .unwrap_or(TrustedProxy(false)) + .0; + + if is_trusted { + if let Some(proto) = parts.headers.get(X_FORWARDED_PROTO) + && let Ok(proto) = proto.to_str() + && let Ok(scheme) = proto.to_lowercase().as_str().try_into() + { + return scheme; + } + + if let Some(proto) = parts.headers.get(X_FORWARDED_SCHEME) + && let Ok(proto) = proto.to_str() + && let Ok(scheme) = proto.to_lowercase().as_str().try_into() + { + return scheme; + } + + if let Some(forwarded) = parts.headers.get(FORWARDED) + && let Ok(forwarded) = forwarded.to_str() + && let Ok(forwarded) = ForwardedHeaderValue::from_forwarded(forwarded) + { + for stanza in forwarded.iter() { + if let Some(forwarded_proto) = &stanza.forwarded_proto { + let scheme = match forwarded_proto { + Protocol::Http => http::uri::Scheme::HTTP, + Protocol::Https => http::uri::Scheme::HTTPS, + }; + return scheme; + } + } + } + + if let Ok(Extension(proxy_protocol_state)) = + parts.extract::>().await + && let Some(header) = &proxy_protocol_state.header + && header.ssl().is_some() + { + return http::uri::Scheme::HTTPS; + } + } + + if parts.extract::>().await.is_ok() { + http::uri::Scheme::HTTPS + } else { + http::uri::Scheme::HTTP + } +} + +/// Middleware required by the [`Scheme`] extractor. +/// +/// Use with [`axum::middleware::from_fn`]. +pub async fn scheme_middleware(request: Request, next: Next) -> Response { + let (mut parts, body) = request.into_parts(); + + let scheme = extract_scheme(&mut parts).await; + Span::current().record("scheme", scheme.to_string()); + parts.extensions.insert::(Scheme(scheme)); + + let request = Request::from_parts(parts, body); + + next.run(request).await +} + +#[cfg(test)] +mod tests { + use axum::{body::Body, http::Request}; + + use super::*; + + #[tokio::test] + async fn x_forwarded_proto_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-proto", "https") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTPS,); + } + + #[tokio::test] + async fn x_forwarded_scheme_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-scheme", "https") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTPS,); + } + + #[tokio::test] + async fn forwarded_header_trusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("forwarded", "proto=https") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTPS,); + } + + #[tokio::test] + async fn x_forwarded_proto_untrusted() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-proto", "https") + .extension(TrustedProxy(false)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTP,); + } + + #[tokio::test] + async fn scheme_from_tls_state() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .extension(TlsState { + peer_certificates: None, + }) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTPS,); + } + + #[tokio::test] + async fn scheme_defaults_to_http() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTP,); + } + + #[tokio::test] + async fn priority_order() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-proto", "http") + .header("x-forwarded-scheme", "https") + .header("forwarded", "proto=https") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTP,); + } + + #[tokio::test] + async fn multiple_forwarded_stanzas() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("forwarded", "proto=http, proto=https") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTP,); + } + + #[tokio::test] + async fn test_scheme_case_insensitive() { + let (mut parts, _) = Request::builder() + .uri("http://example.com/path") + .header("x-forwarded-proto", "HTTPS") + .extension(TrustedProxy(true)) + .body(Body::empty()) + .expect("Failed to create request") + .into_parts(); + + let scheme = extract_scheme(&mut parts).await; + + assert_eq!(scheme, http::uri::Scheme::HTTPS,); + } +} diff --git a/packages/ak-axum/src/router.rs b/packages/ak-axum/src/router.rs index a86cb70c72..2caceeab0c 100644 --- a/packages/ak-axum/src/router.rs +++ b/packages/ak-axum/src/router.rs @@ -6,7 +6,10 @@ use tower::ServiceBuilder; use tower_http::timeout::TimeoutLayer; use crate::{ - extract::{client_ip::client_ip_middleware, trusted_proxy::trusted_proxy_middleware}, + extract::{ + client_ip::client_ip_middleware, scheme::scheme_middleware, + trusted_proxy::trusted_proxy_middleware, + }, tracing::{span_middleware, tracing_middleware}, }; @@ -28,7 +31,8 @@ pub fn wrap_router(router: Router, with_tracing: bool) -> Router { )) .layer(from_fn(span_middleware)) .layer(from_fn(trusted_proxy_middleware)) - .layer(from_fn(client_ip_middleware)); + .layer(from_fn(client_ip_middleware)) + .layer(from_fn(scheme_middleware)); if with_tracing { router.layer(service_builder.layer(from_fn(tracing_middleware))) } else { diff --git a/packages/ak-axum/src/tracing.rs b/packages/ak-axum/src/tracing.rs index 5f2bf196fe..0c7b69a91c 100644 --- a/packages/ak-axum/src/tracing.rs +++ b/packages/ak-axum/src/tracing.rs @@ -28,6 +28,7 @@ pub(crate) async fn span_middleware(request: Request, next: Next) -> Response { path = %request.uri(), method = %request.method(), remote = field::Empty, + scheme = field::Empty, http_headers = ?http_headers, ); next.run(request).instrument(span).await