mirror of
https://github.com/goauthentik/authentik
synced 2026-04-25 17:15:26 +02:00
87 lines
2.3 KiB
Rust
87 lines
2.3 KiB
Rust
use std::{io, time::Duration};
|
|
|
|
use axum::{Extension, middleware::AddExtension};
|
|
use axum_server::accept::{Accept, DefaultAcceptor};
|
|
use futures::future::BoxFuture;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
use tower::Layer as _;
|
|
use tracing::instrument;
|
|
|
|
use crate::tokio::proxy_protocol::{ProxyProtocolStream, header::Header};
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub(crate) struct ProxyProtocolState {
|
|
pub(crate) header: Option<Header<'static>>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub(crate) struct ProxyProtocolAcceptor<A = DefaultAcceptor> {
|
|
inner: A,
|
|
parsing_timeout: Duration,
|
|
}
|
|
|
|
impl ProxyProtocolAcceptor {
|
|
pub(crate) fn new() -> Self {
|
|
let inner = DefaultAcceptor::new();
|
|
|
|
#[cfg(not(test))]
|
|
let parsing_timeout = Duration::from_secs(10);
|
|
|
|
// Don't force tests to wait too long
|
|
#[cfg(test)]
|
|
let parsing_timeout = Duration::from_secs(1);
|
|
|
|
Self {
|
|
inner,
|
|
parsing_timeout,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for ProxyProtocolAcceptor {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl<A> ProxyProtocolAcceptor<A> {
|
|
pub(crate) fn acceptor<Acceptor>(self, acceptor: Acceptor) -> ProxyProtocolAcceptor<Acceptor> {
|
|
ProxyProtocolAcceptor {
|
|
inner: acceptor,
|
|
parsing_timeout: self.parsing_timeout,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<A, I, S> Accept<I, S> for ProxyProtocolAcceptor<A>
|
|
where
|
|
A: Accept<I, S> + Clone + Send + 'static,
|
|
A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
|
|
A::Service: Send,
|
|
A::Future: Send,
|
|
I: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
S: Send + 'static,
|
|
{
|
|
type Future = BoxFuture<'static, io::Result<(Self::Stream, Self::Service)>>;
|
|
type Service = AddExtension<A::Service, ProxyProtocolState>;
|
|
type Stream = ProxyProtocolStream<A::Stream>;
|
|
|
|
#[instrument(skip_all)]
|
|
fn accept(&self, stream: I, service: S) -> Self::Future {
|
|
let acceptor = self.inner.clone();
|
|
|
|
Box::pin(async move {
|
|
let (stream, service) = acceptor.accept(stream, service).await?;
|
|
let stream = ProxyProtocolStream::new(stream).await?;
|
|
|
|
let proxy_protocol_state = ProxyProtocolState {
|
|
header: stream.header().cloned(),
|
|
};
|
|
|
|
let service = Extension(proxy_protocol_state).layer(service);
|
|
|
|
Ok((stream, service))
|
|
})
|
|
}
|
|
}
|