diff --git a/Cargo.lock b/Cargo.lock index 03bda258019..5e13759eee2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -749,12 +749,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -3463,25 +3457,6 @@ dependencies = [ "svg_fmt", ] -[[package]] -name = "h2" -version = "0.3.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http 0.2.12", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "h2" version = "0.4.12" @@ -3572,21 +3547,6 @@ dependencies = [ "hashbrown 0.15.5", ] -[[package]] -name = "headers" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" -dependencies = [ - "base64 0.21.7", - "bytes", - "headers-core 0.2.0", - "http 0.2.12", - "httpdate", - "mime", - "sha1", -] - [[package]] name = "headers" version = "0.4.1" @@ -3595,22 +3555,13 @@ checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" dependencies = [ "base64 0.22.1", "bytes", - "headers-core 0.3.0", + "headers-core", "http 1.4.0", "httpdate", "mime", "sha1", ] -[[package]] -name = "headers-core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" -dependencies = [ - "http 0.2.12", -] - [[package]] name = "headers-core" version = "0.3.0" @@ -3742,17 +3693,6 @@ dependencies = [ "itoa", ] -[[package]] -name = "http-body" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" -dependencies = [ - "bytes", - "http 0.2.12", - "pin-project-lite", -] - [[package]] name = "http-body" version = "1.0.1" @@ -3772,7 +3712,7 @@ dependencies = [ "bytes", "futures-core", "http 1.4.0", - "http-body 1.0.1", + "http-body", "pin-project-lite", ] @@ -3806,30 +3746,6 @@ dependencies = [ "typenum", ] -[[package]] -name = "hyper" -version = "0.14.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2 0.3.27", - "http 0.2.12", - "http-body 0.4.6", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2 0.5.10", - "tokio", - "tower-service", - "tracing", - "want", -] - [[package]] name = "hyper" version = "1.9.0" @@ -3840,9 +3756,9 @@ dependencies = [ "bytes", "futures-channel", "futures-core", - "h2 0.4.12", + "h2", "http 1.4.0", - "http-body 1.0.1", + "http-body", "httparse", "httpdate", "itoa", @@ -3859,7 +3775,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2b52f86d1d4bc0d6b4e6826d960b1b333217e07d36b882dca570a5e1c48895b" dependencies = [ "http 1.4.0", - "hyper 1.9.0", + "hyper", "hyper-util", "log", "rustls", @@ -3880,13 +3796,13 @@ dependencies = [ "futures-channel", "futures-util", "http 1.4.0", - "http-body 1.0.1", - "hyper 1.9.0", + "http-body", + "hyper", "ipnet", "libc", "percent-encoding", "pin-project-lite", - "socket2 0.6.3", + "socket2", "tokio", "tower-service", "tracing", @@ -7544,7 +7460,7 @@ dependencies = [ "gstreamer", "http 1.4.0", "http-body-util", - "hyper 1.9.0", + "hyper", "image", "ipc-channel", "itertools 0.14.0", @@ -7870,7 +7786,7 @@ dependencies = [ "base64 0.22.1", "chrono", "crossbeam-channel", - "headers 0.4.1", + "headers", "http 1.4.0", "log", "malloc_size_of_derive", @@ -8042,9 +7958,9 @@ name = "servo-hyper-serde" version = "0.1.0" dependencies = [ "cookie 0.18.1", - "headers 0.4.1", + "headers", "http 1.4.0", - "hyper 1.9.0", + "hyper", "mime", "serde", "serde_bytes", @@ -8451,10 +8367,10 @@ dependencies = [ "futures-core", "futures-util", "generic-array", - "headers 0.4.1", + "headers", "http 1.4.0", "http-body-util", - "hyper 1.9.0", + "hyper", "hyper-rustls", "hyper-util", "imsz", @@ -8512,7 +8428,7 @@ dependencies = [ "cookie 0.18.1", "crossbeam-channel", "data-url", - "headers 0.4.1", + "headers", "http 1.4.0", "hyper-util", "ipc-channel", @@ -8705,7 +8621,7 @@ dependencies = [ "euclid", "flate2", "glow 0.17.0", - "headers 0.4.1", + "headers", "hkdf", "html5ever", "http 1.4.0", @@ -8967,10 +8883,12 @@ name = "servo-webdriver-server" version = "0.1.0" dependencies = [ "base64 0.22.1", + "bytes", "cookie 0.18.1", "crossbeam-channel", "euclid", "http 0.2.12", + "http 1.4.0", "image", "keyboard-types", "log", @@ -8984,7 +8902,10 @@ dependencies = [ "servo-url", "stylo_traits", "time", + "tokio", + "url", "uuid", + "warp", "webdriver", ] @@ -9118,7 +9039,7 @@ dependencies = [ "euclid", "gilrs", "glow 0.17.0", - "headers 0.4.1", + "headers", "hilog", "hitrace", "image", @@ -9357,16 +9278,6 @@ dependencies = [ "serde", ] -[[package]] -name = "socket2" -version = "0.5.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" -dependencies = [ - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "socket2" version = "0.6.3" @@ -10074,7 +9985,7 @@ dependencies = [ "libc", "mio", "pin-project-lite", - "socket2 0.6.3", + "socket2", "tokio-macros", "windows-sys 0.61.2", ] @@ -10754,16 +10665,18 @@ dependencies = [ [[package]] name = "warp" -version = "0.3.7" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c" +checksum = "51d06d9202adc1f15d709c4f4a2069be5428aa912cc025d6f268ac441ab066b0" dependencies = [ "bytes", - "futures-channel", "futures-util", - "headers 0.3.9", - "http 0.2.12", - "hyper 0.14.32", + "headers", + "http 1.4.0", + "http-body", + "http-body-util", + "hyper", + "hyper-util", "log", "mime", "mime_guess", @@ -11056,10 +10969,7 @@ dependencies = [ "serde_json", "thiserror 1.0.69", "time", - "tokio", - "tokio-stream", "url", - "warp", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 30bb3af0b76..67ff60d0ba6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ bitflags = "2.11" brotli = "8.0.2" bytemuck = "1" byteorder = "1.5" +bytes = "1.0" cbc = "0.1.2" cfg-if = "1.0.4" chacha20poly1305 = "0.10" @@ -208,7 +209,7 @@ utf-8 = "0.7" uuid = { version = "1.23.0", features = ["v4", "v5"] } vello = "0.6" vello_cpu = "0.0.4" -webdriver = "0.53.0" +webdriver = { version = "0.53.0", default-features = false } webpki-roots = "1.0" webrender = { version = "0.68", features = ["capture"] } webrender_api = "0.68" diff --git a/components/net/Cargo.toml b/components/net/Cargo.toml index f0c4015fe11..bb87f6e5824 100644 --- a/components/net/Cargo.toml +++ b/components/net/Cargo.toml @@ -76,7 +76,7 @@ servo-url = { workspace = true } sha2 = { workspace = true } tracing = { workspace = true, optional = true } time = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] } +tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync", "fs"] } tokio-rustls = { workspace = true } tokio-stream = { workspace = true } tokio-util = { workspace = true, features = ["codec", "io"] } diff --git a/components/webdriver_server/Cargo.toml b/components/webdriver_server/Cargo.toml index 29533c5de2b..6b9e856be2f 100644 --- a/components/webdriver_server/Cargo.toml +++ b/components/webdriver_server/Cargo.toml @@ -15,11 +15,13 @@ path = "lib.rs" [dependencies] base64 = { workspace = true } +bytes = { workspace = true } cookie = { workspace = true } crossbeam-channel = { workspace = true } embedder_traits = { workspace = true } euclid = { workspace = true } -http = "0.2" +http = { workspace = true } +http02 = { package = "http", version = "0.2" } image = { workspace = true } keyboard-types = { workspace = true } log = { workspace = true } @@ -32,5 +34,8 @@ servo-geometry = { workspace = true } servo-url = { workspace = true } stylo_traits = { workspace = true } time = { workspace = true } +tokio = { workspace = true } uuid = { workspace = true } webdriver = { workspace = true } +url = { workspace = true } +warp = { version = "0.4.2", features = ["server"] } diff --git a/components/webdriver_server/lib.rs b/components/webdriver_server/lib.rs index 1c7a6bf42f7..7f7c25a978f 100644 --- a/components/webdriver_server/lib.rs +++ b/components/webdriver_server/lib.rs @@ -9,6 +9,7 @@ mod actions; mod capabilities; mod script_argument_extraction; +mod server; mod session; mod timeout; mod user_prompt; @@ -42,6 +43,7 @@ use serde::de::{Deserializer, MapAccess, Visitor}; use serde::ser::Serializer; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use server::{Session, SessionTeardownKind, WebDriverHandler}; use servo_base::generic_channel::{self, GenericReceiver, GenericSender, RoutedReceiver}; use servo_base::id::{BrowsingContextId, WebViewId}; use servo_config::prefs::{self, PrefValue, Preferences}; @@ -71,7 +73,6 @@ use webdriver::response::{ CloseWindowResponse, CookieResponse, CookiesResponse, ElementRectResponse, NewSessionResponse, NewWindowResponse, TimeoutsResponse, ValueResponse, WebDriverResponse, WindowRectResponse, }; -use webdriver::server::{self, Session, SessionTeardownKind, WebDriverHandler}; use crate::actions::{ELEMENT_CLICK_BUTTON, InputSourceState, PendingActions, PointerInputState}; use crate::session::{PageLoadStrategy, WebDriverSession}; diff --git a/components/webdriver_server/server.rs b/components/webdriver_server/server.rs new file mode 100644 index 00000000000..467dc08912c --- /dev/null +++ b/components/webdriver_server/server.rs @@ -0,0 +1,731 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// The code in this file was initially forked from the server module of the webdriver crate. +// https://github.com/mozilla-firefox/firefox/blob/63719d122f9214f37fd1d285a91897b8345b88b0/testing/webdriver/src/server.rs + +use std::marker::PhantomData; +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::sync::mpsc::{Receiver, Sender, channel}; +use std::sync::{Arc, Mutex}; +use std::thread; + +use bytes::Bytes; +use http::{Method, StatusCode}; +use log::{debug, error, trace, warn}; +use tokio::net::TcpListener; +use url::{Host, Url}; +use warp::{Buf, Filter, Rejection}; +use webdriver::command::{WebDriverCommand, WebDriverMessage}; +use webdriver::error::{ErrorStatus, WebDriverError, WebDriverResult}; +use webdriver::httpapi::{ + Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute, standard_routes, +}; +use webdriver::response::{CloseWindowResponse, WebDriverResponse}; + +use crate::Parameters; + +// Silence warning about Quit being unused for now. +#[allow(dead_code)] +enum DispatchMessage { + HandleWebDriver( + WebDriverMessage, + Sender>, + ), + Quit, +} + +#[derive(Clone, Debug, PartialEq)] +/// Representation of whether we managed to successfully send a DeleteSession message +/// and read the response during session teardown. +pub enum SessionTeardownKind { + /// A DeleteSession message has been sent and the response handled. + Deleted, + /// No DeleteSession message has been sent, or the response was not received. + NotDeleted, +} + +#[derive(Clone, Debug, PartialEq)] +pub struct Session { + pub id: String, +} + +impl Session { + fn new(id: String) -> Session { + Session { id } + } +} + +pub trait WebDriverHandler: Send { + fn handle_command( + &mut self, + session: &Option, + msg: WebDriverMessage, + ) -> WebDriverResult; + fn teardown_session(&mut self, kind: SessionTeardownKind); +} + +#[derive(Debug)] +struct Dispatcher, U: WebDriverExtensionRoute> { + handler: T, + session: Option, + extension_type: PhantomData, +} + +impl, U: WebDriverExtensionRoute> Dispatcher { + fn new(handler: T) -> Dispatcher { + Dispatcher { + handler, + session: None, + extension_type: PhantomData, + } + } + + fn run(&mut self, msg_chan: &Receiver>) { + loop { + match msg_chan.recv() { + Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => { + let resp = match self.check_session(&msg) { + Ok(_) => self.handler.handle_command(&self.session, msg), + Err(e) => Err(e), + }; + + match resp { + Ok(WebDriverResponse::NewSession(ref new_session)) => { + self.session = Some(Session::new(new_session.session_id.clone())); + }, + Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => { + if handles.is_empty() { + debug!("Last window was closed, deleting session"); + // The teardown_session implementation is responsible for actually + // sending the DeleteSession message in this case + self.teardown_session(SessionTeardownKind::NotDeleted); + } + }, + Ok(WebDriverResponse::DeleteSession) => { + self.teardown_session(SessionTeardownKind::Deleted); + }, + Err(ref x) if x.delete_session => { + // This includes the case where we failed during session creation + self.teardown_session(SessionTeardownKind::NotDeleted) + }, + _ => {}, + } + + if resp_chan.send(resp).is_err() { + error!("Sending response to the main thread failed"); + }; + }, + Ok(DispatchMessage::Quit) => break, + Err(e) => panic!("Error receiving message in handler: {:?}", e), + } + } + } + + fn teardown_session(&mut self, kind: SessionTeardownKind) { + debug!("Teardown session"); + let final_kind = match kind { + SessionTeardownKind::NotDeleted if self.session.is_some() => { + let delete_session = WebDriverMessage { + session_id: Some( + self.session + .as_ref() + .expect("Failed to get session") + .id + .clone(), + ), + command: WebDriverCommand::DeleteSession, + }; + match self.handler.handle_command(&self.session, delete_session) { + Ok(_) => SessionTeardownKind::Deleted, + Err(_) => SessionTeardownKind::NotDeleted, + } + }, + _ => kind, + }; + self.handler.teardown_session(final_kind); + self.session = None; + } + + fn check_session(&self, msg: &WebDriverMessage) -> WebDriverResult<()> { + match msg.session_id { + Some(ref msg_session_id) => match self.session { + Some(ref existing_session) => { + if existing_session.id != *msg_session_id { + Err(WebDriverError::new( + ErrorStatus::InvalidSessionId, + format!("Got unexpected session id {}", msg_session_id), + )) + } else { + Ok(()) + } + }, + None => Ok(()), + }, + None => { + match self.session { + Some(_) => { + match msg.command { + WebDriverCommand::Status => Ok(()), + WebDriverCommand::NewSession(_) => Err(WebDriverError::new( + ErrorStatus::SessionNotCreated, + "Session is already started", + )), + _ => { + // This should be impossible + error!("Got a message with no session id"); + Err(WebDriverError::new( + ErrorStatus::UnknownError, + "Got a command with no session?!", + )) + }, + } + }, + None => match msg.command { + WebDriverCommand::NewSession(_) => Ok(()), + WebDriverCommand::Status => Ok(()), + _ => Err(WebDriverError::new( + ErrorStatus::InvalidSessionId, + "Tried to run a command before creating a session", + )), + }, + } + }, + } + } +} + +pub struct Listener { + guard: Option>, + pub socket: SocketAddr, +} + +impl Drop for Listener { + fn drop(&mut self) { + let _ = self.guard.take().map(|j| j.join()); + } +} + +pub fn start( + mut address: SocketAddr, + allow_hosts: Vec, + allow_origins: Vec, + handler: T, + extension_routes: Vec<(Method, &'static str, U)>, +) -> ::std::io::Result +where + T: 'static + WebDriverHandler, + U: 'static + WebDriverExtensionRoute + Send + Sync, +{ + let listener = StdTcpListener::bind(address)?; + listener.set_nonblocking(true)?; + let addr = listener.local_addr()?; + if address.port() == 0 { + // If we passed in 0 as the port number the OS will assign an unused port; + // we want to update the address to the actual used port + address.set_port(addr.port()) + } + let (msg_send, msg_recv) = channel(); + + let builder = thread::Builder::new().name("webdriver server".to_string()); + let handle = builder.spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .build() + .unwrap(); + let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() }); + let wroutes = build_warp_routes( + address, + allow_hosts, + allow_origins, + &extension_routes, + msg_send.clone(), + ); + let fut = warp::serve(wroutes).incoming(listener).run(); + rt.block_on(fut); + })?; + + let builder = thread::Builder::new().name("webdriver dispatcher".to_string()); + builder.spawn(move || { + let mut dispatcher = Dispatcher::new(handler); + dispatcher.run(&msg_recv); + })?; + + Ok(Listener { + guard: Some(handle), + socket: addr, + }) +} + +fn build_warp_routes( + address: SocketAddr, + allow_hosts: Vec, + allow_origins: Vec, + ext_routes: &[(Method, &'static str, U)], + chan: Sender>, +) -> impl Filter + Clone + 'static { + let chan = Arc::new(Mutex::new(chan)); + let mut std_routes = standard_routes::(); + + let (method, path, res) = std_routes.pop().unwrap(); + trace!("Build standard route for {path}"); + let mut wroutes = build_route( + address, + allow_hosts.clone(), + allow_origins.clone(), + convert_method(method), + path, + res, + chan.clone(), + ); + + for (method, path, res) in std_routes { + trace!("Build standard route for {path}"); + wroutes = wroutes + .or(build_route( + address, + allow_hosts.clone(), + allow_origins.clone(), + convert_method(method), + path, + res.clone(), + chan.clone(), + )) + .unify() + .boxed() + } + + for (method, path, res) in ext_routes { + trace!("Build vendor route for {path}"); + wroutes = wroutes + .or(build_route( + address, + allow_hosts.clone(), + allow_origins.clone(), + method.clone(), + path, + Route::Extension(res.clone()), + chan.clone(), + )) + .unify() + .boxed() + } + + wroutes +} + +fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool { + // Validate that the Host header value has a hostname in allow_hosts and + // the port matches the server configuration + let header_host_url = match Url::parse(&format!("http://{}", &host_header)) { + Ok(x) => x, + Err(_) => { + return false; + }, + }; + + let host = match header_host_url.host() { + Some(host) => host.to_owned(), + None => { + // This shouldn't be possible since http URL always have a + // host, but conservatively return false here, which will cause + // an error response + return false; + }, + }; + let port = match header_host_url.port_or_known_default() { + Some(port) => port, + None => { + // This shouldn't be possible since http URL always have a + // default port, but conservatively return false here, which will cause + // an error response + return false; + }, + }; + + let host_matches = match host { + Host::Domain(_) => allow_hosts.contains(&host), + Host::Ipv4(_) | Host::Ipv6(_) => true, + }; + let port_matches = server_address.port() == port; + host_matches && port_matches +} + +fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool { + // Validate that the Origin header value is in allow_origins + allow_origins.contains(&origin_url) +} + +fn build_route( + server_address: SocketAddr, + allow_hosts: Vec, + allow_origins: Vec, + method: Method, + path: &'static str, + route: Route, + chan: Arc>>>, +) -> warp::filters::BoxedFilter<(impl warp::Reply,)> { + // Create an empty filter based on the provided method and append an empty hashmap to it. The + // hashmap will be used to store path parameters. + let mut subroute = match method { + Method::GET => warp::get().boxed(), + Method::POST => warp::post().boxed(), + Method::DELETE => warp::delete().boxed(), + Method::OPTIONS => warp::options().boxed(), + Method::PUT => warp::put().boxed(), + _ => panic!("Unsupported method"), + } + .or(warp::head()) + .unify() + .map(Parameters::new) + .boxed(); + + // For each part of the path, if it's a normal part, just append it to the current filter, + // otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert + // it into the hashmap created earlier. + for part in path.split('/') { + if part.is_empty() { + continue; + } else if part.starts_with('{') { + assert!(part.ends_with('}')); + + subroute = subroute + .and(warp::path::param()) + .map(move |mut params: Parameters, param: String| { + let name = &part[1..part.len() - 1]; + params.insert(name.to_string(), param); + params + }) + .boxed(); + } else { + subroute = subroute.and(warp::path(part)).boxed(); + } + } + + // Finally, tell warp that the path is complete + subroute + .and(warp::path::end()) + .and(warp::path::full()) + .and(warp::method()) + .and(warp::header::optional::("origin")) + .and(warp::header::optional::("host")) + .and(warp::header::optional::("content-type")) + .and(warp::body::bytes()) + .map( + move |params, + full_path: warp::path::FullPath, + method, + origin_header: Option, + host_header: Option, + content_type_header: Option, + body: Bytes| { + if method == Method::HEAD { + return warp::reply::with_status("".into(), StatusCode::OK); + } + if let Some(host) = host_header { + if !is_host_allowed(&server_address, &allow_hosts, &host) { + warn!( + "Rejected request with Host header {}, allowed values are [{}]", + host, + allow_hosts + .iter() + .map(|x| format!("{}:{}", x, server_address.port())) + .collect::>() + .join(",") + ); + let err = WebDriverError::new( + ErrorStatus::UnknownError, + format!("Invalid Host header {}", host), + ); + return warp::reply::with_status( + serde_json::to_string(&err).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + }; + } else { + warn!("Rejected request with missing Host header"); + let err = WebDriverError::new( + ErrorStatus::UnknownError, + "Missing Host header".to_string(), + ); + return warp::reply::with_status( + serde_json::to_string(&err).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + } + if let Some(origin) = origin_header { + let make_err = || { + warn!( + "Rejected request with Origin header {}, allowed values are [{}]", + origin, + allow_origins + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(",") + ); + WebDriverError::new( + ErrorStatus::UnknownError, + format!("Invalid Origin header {}", origin), + ) + }; + let origin_url = match Url::parse(&origin) { + Ok(url) => url, + Err(_) => { + return warp::reply::with_status( + serde_json::to_string(&make_err()).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + }, + }; + if !is_origin_allowed(&allow_origins, origin_url) { + return warp::reply::with_status( + serde_json::to_string(&make_err()).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + } + } + if method == Method::POST { + // Disallow CORS-safelisted request headers + // c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header + let content_type = content_type_header + .as_ref() + .map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x)) + .map(|x| x.trim()) + .map(|x| x.to_lowercase()); + match content_type.as_ref().map(|x| x.as_ref()) { + Some("application/x-www-form-urlencoded") | + Some("multipart/form-data") | + Some("text/plain") => { + warn!( + "Rejected POST request with disallowed content type {}", + content_type.unwrap_or_else(|| "".into()) + ); + let err = WebDriverError::new( + ErrorStatus::UnknownError, + "Invalid Content-Type", + ); + return warp::reply::with_status( + serde_json::to_string(&err).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + }, + Some(_) | None => {}, + } + } + let body = String::from_utf8(body.chunk().to_vec()); + if body.is_err() { + let err = WebDriverError::new( + ErrorStatus::UnknownError, + "Request body wasn't valid UTF-8", + ); + return warp::reply::with_status( + serde_json::to_string(&err).unwrap(), + StatusCode::INTERNAL_SERVER_ERROR, + ); + } + let body = body.unwrap(); + + debug!("-> {} {} {}", method, full_path.as_str(), body); + let msg_result = WebDriverMessage::from_http( + route.clone(), + ¶ms, + &body, + method == Method::POST, + ); + + let (status, resp_body) = match msg_result { + Ok(message) => { + let (send_res, recv_res) = channel(); + match chan.lock() { + Ok(ref c) => { + let res = + c.send(DispatchMessage::HandleWebDriver(message, send_res)); + match res { + Ok(x) => x, + Err(e) => panic!("Error: {:?}", e), + } + }, + Err(e) => panic!("Error reading response: {:?}", e), + } + + match recv_res.recv() { + Ok(data) => match data { + Ok(response) => { + (StatusCode::OK, serde_json::to_string(&response).unwrap()) + }, + Err(e) => ( + StatusCode::from_u16(e.http_status().as_u16()).unwrap(), + serde_json::to_string(&e).unwrap(), + ), + }, + Err(e) => panic!("Error reading response: {:?}", e), + } + }, + Err(e) => ( + convert_status(e.http_status()), + serde_json::to_string(&e).unwrap(), + ), + }; + + debug!("<- {} {}", status, resp_body); + warp::reply::with_status(resp_body, status) + }, + ) + .with(warp::reply::with::header( + http::header::CONTENT_TYPE, + "application/json; charset=utf-8", + )) + .with(warp::reply::with::header( + http::header::CACHE_CONTROL, + "no-cache", + )) + .boxed() +} + +/// Convert from http 0.2 StatusCode to http 1.0 StatusCode +fn convert_status(status: http02::StatusCode) -> StatusCode { + StatusCode::from_u16(status.as_u16()).unwrap() +} + +/// Convert from http 0.2 Method to http 1.0 Method +fn convert_method(method: http02::Method) -> Method { + match method { + http02::Method::OPTIONS => http::Method::OPTIONS, + http02::Method::GET => http::Method::GET, + http02::Method::POST => http::Method::POST, + http02::Method::PUT => http::Method::PUT, + http02::Method::DELETE => http::Method::DELETE, + http02::Method::HEAD => http::Method::HEAD, + http02::Method::TRACE => http::Method::TRACE, + http02::Method::CONNECT => http::Method::CONNECT, + http02::Method::PATCH => http::Method::PATCH, + _ => http::Method::from_bytes(method.as_str().as_bytes()).unwrap(), + } +} + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + use std::str::FromStr; + + use super::*; + + #[test] + fn test_host_allowed() { + let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); + let addr_8000 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 8000); + let addr_v6_80 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 80); + let addr_v6_8000 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 8000); + + // We match the host ip address to the server, so we can only use hosts that actually resolve + let localhost_host = Host::Domain("localhost".to_string()); + let test_host = Host::Domain("example.test".to_string()); + let subdomain_localhost_host = Host::Domain("subdomain.localhost".to_string()); + + assert!(is_host_allowed( + &addr_80, + &[localhost_host.clone()], + "localhost:80" + )); + assert!(is_host_allowed( + &addr_80, + &[test_host.clone()], + "example.test:80" + )); + assert!(is_host_allowed( + &addr_80, + &[test_host.clone(), localhost_host.clone()], + "example.test" + )); + assert!(is_host_allowed( + &addr_80, + &[subdomain_localhost_host.clone()], + "subdomain.localhost" + )); + + // ip address cases + assert!(is_host_allowed(&addr_80, &[], "127.0.0.1:80")); + assert!(is_host_allowed(&addr_v6_80, &[], "127.0.0.1")); + assert!(is_host_allowed(&addr_80, &[], "[::1]")); + assert!(is_host_allowed(&addr_8000, &[], "127.0.0.1:8000")); + assert!(is_host_allowed( + &addr_80, + &[subdomain_localhost_host.clone()], + "[::1]" + )); + assert!(is_host_allowed( + &addr_v6_8000, + &[subdomain_localhost_host.clone()], + "[::1]:8000" + )); + + // Mismatch cases + + assert!(!is_host_allowed(&addr_80, &[test_host], "localhost")); + + assert!(!is_host_allowed(&addr_80, &[], "localhost:80")); + + // Port mismatch cases + + assert!(!is_host_allowed( + &addr_80, + &[localhost_host.clone()], + "localhost:8000" + )); + assert!(!is_host_allowed( + &addr_8000, + &[localhost_host.clone()], + "localhost" + )); + assert!(!is_host_allowed( + &addr_v6_8000, + &[localhost_host.clone()], + "[::1]" + )); + } + + #[test] + fn test_origin_allowed() { + assert!(is_origin_allowed( + &[Url::parse("http://localhost").unwrap()], + Url::parse("http://localhost").unwrap() + )); + assert!(is_origin_allowed( + &[Url::parse("http://localhost").unwrap()], + Url::parse("http://localhost:80").unwrap() + )); + assert!(is_origin_allowed( + &[ + Url::parse("https://test.example").unwrap(), + Url::parse("http://localhost").unwrap() + ], + Url::parse("http://localhost").unwrap() + )); + assert!(is_origin_allowed( + &[ + Url::parse("https://test.example").unwrap(), + Url::parse("http://localhost").unwrap() + ], + Url::parse("https://test.example:443").unwrap() + )); + // Mismatch cases + assert!(!is_origin_allowed( + &[], + Url::parse("http://localhost").unwrap() + )); + assert!(!is_origin_allowed( + &[Url::parse("http://localhost").unwrap()], + Url::parse("http://localhost:8000").unwrap() + )); + assert!(!is_origin_allowed( + &[Url::parse("https://localhost").unwrap()], + Url::parse("http://localhost").unwrap() + )); + assert!(!is_origin_allowed( + &[Url::parse("https://example.test").unwrap()], + Url::parse("http://subdomain.example.test").unwrap() + )); + } +} diff --git a/deny.toml b/deny.toml index 85b35987b6a..0573f66ece5 100644 --- a/deny.toml +++ b/deny.toml @@ -153,13 +153,7 @@ skip = [ "thiserror-impl", # duplicated by webdriver - "h2", - "headers", - "headers-core", "http", - "http-body", - "hyper", - "socket2", # duplicated by winit "block2",