mirror of
https://github.com/servo/servo
synced 2026-04-25 17:15:48 +02:00
Use our own http server for the webdriver server (#44338)
Helps with: https://github.com/servo/servo/issues/38776. Reduces total Servo crate count by 7 (977 -> 970). This PR simply: - Disables the `server` feature in the `webdriver` crate - Vendors the implementation of the server from the `webdriver` crate - Updates dependencies + fixes code to work with new versions Unfortunately `webdriver` depends on `http` even with the `server` feature disabled, so we still end up with duplicate versions of `http`. But at least the duplicate `hyper` is eliminated. Future work could change the implementation to e.g. move away from `warp` or similar. Testing: WPT tests use webdriver, so this should be exercised heavily by those tests. --------- Signed-off-by: Nico Burns <nico@nicoburns.com>
This commit is contained in:
152
Cargo.lock
generated
152
Cargo.lock
generated
@@ -749,12 +749,6 @@ version = "0.12.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
@@ -3463,25 +3457,6 @@ dependencies = [
|
||||
"svg_fmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.3.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0beca50380b1fc32983fc1cb4587bfa4bb9e78fc259aad4a0032d2080309222d"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"fnv",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"futures-util",
|
||||
"http 0.2.12",
|
||||
"indexmap",
|
||||
"slab",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "h2"
|
||||
version = "0.4.12"
|
||||
@@ -3572,21 +3547,6 @@ dependencies = [
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"bytes",
|
||||
"headers-core 0.2.0",
|
||||
"http 0.2.12",
|
||||
"httpdate",
|
||||
"mime",
|
||||
"sha1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
version = "0.4.1"
|
||||
@@ -3595,22 +3555,13 @@ checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"headers-core 0.3.0",
|
||||
"headers-core",
|
||||
"http 1.4.0",
|
||||
"httpdate",
|
||||
"mime",
|
||||
"sha1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers-core"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
|
||||
dependencies = [
|
||||
"http 0.2.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers-core"
|
||||
version = "0.3.0"
|
||||
@@ -3742,17 +3693,6 @@ dependencies = [
|
||||
"itoa",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "0.4.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"http 0.2.12",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "http-body"
|
||||
version = "1.0.1"
|
||||
@@ -3772,7 +3712,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-core",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -3806,30 +3746,6 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "0.14.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"h2 0.3.27",
|
||||
"http 0.2.12",
|
||||
"http-body 0.4.6",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"socket2 0.5.10",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
"want",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hyper"
|
||||
version = "1.9.0"
|
||||
@@ -3840,9 +3756,9 @@ dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"h2 0.4.12",
|
||||
"h2",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"http-body",
|
||||
"httparse",
|
||||
"httpdate",
|
||||
"itoa",
|
||||
@@ -3859,7 +3775,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2b52f86d1d4bc0d6b4e6826d960b1b333217e07d36b882dca570a5e1c48895b"
|
||||
dependencies = [
|
||||
"http 1.4.0",
|
||||
"hyper 1.9.0",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"rustls",
|
||||
@@ -3880,13 +3796,13 @@ dependencies = [
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"http 1.4.0",
|
||||
"http-body 1.0.1",
|
||||
"hyper 1.9.0",
|
||||
"http-body",
|
||||
"hyper",
|
||||
"ipnet",
|
||||
"libc",
|
||||
"percent-encoding",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.3",
|
||||
"socket2",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@@ -7544,7 +7460,7 @@ dependencies = [
|
||||
"gstreamer",
|
||||
"http 1.4.0",
|
||||
"http-body-util",
|
||||
"hyper 1.9.0",
|
||||
"hyper",
|
||||
"image",
|
||||
"ipc-channel",
|
||||
"itertools 0.14.0",
|
||||
@@ -7870,7 +7786,7 @@ dependencies = [
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"crossbeam-channel",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"http 1.4.0",
|
||||
"log",
|
||||
"malloc_size_of_derive",
|
||||
@@ -8042,9 +7958,9 @@ name = "servo-hyper-serde"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cookie 0.18.1",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"http 1.4.0",
|
||||
"hyper 1.9.0",
|
||||
"hyper",
|
||||
"mime",
|
||||
"serde",
|
||||
"serde_bytes",
|
||||
@@ -8451,10 +8367,10 @@ dependencies = [
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"generic-array",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"http 1.4.0",
|
||||
"http-body-util",
|
||||
"hyper 1.9.0",
|
||||
"hyper",
|
||||
"hyper-rustls",
|
||||
"hyper-util",
|
||||
"imsz",
|
||||
@@ -8512,7 +8428,7 @@ dependencies = [
|
||||
"cookie 0.18.1",
|
||||
"crossbeam-channel",
|
||||
"data-url",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"http 1.4.0",
|
||||
"hyper-util",
|
||||
"ipc-channel",
|
||||
@@ -8705,7 +8621,7 @@ dependencies = [
|
||||
"euclid",
|
||||
"flate2",
|
||||
"glow 0.17.0",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"hkdf",
|
||||
"html5ever",
|
||||
"http 1.4.0",
|
||||
@@ -8967,10 +8883,12 @@ name = "servo-webdriver-server"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"cookie 0.18.1",
|
||||
"crossbeam-channel",
|
||||
"euclid",
|
||||
"http 0.2.12",
|
||||
"http 1.4.0",
|
||||
"image",
|
||||
"keyboard-types",
|
||||
"log",
|
||||
@@ -8984,7 +8902,10 @@ dependencies = [
|
||||
"servo-url",
|
||||
"stylo_traits",
|
||||
"time",
|
||||
"tokio",
|
||||
"url",
|
||||
"uuid",
|
||||
"warp",
|
||||
"webdriver",
|
||||
]
|
||||
|
||||
@@ -9118,7 +9039,7 @@ dependencies = [
|
||||
"euclid",
|
||||
"gilrs",
|
||||
"glow 0.17.0",
|
||||
"headers 0.4.1",
|
||||
"headers",
|
||||
"hilog",
|
||||
"hitrace",
|
||||
"image",
|
||||
@@ -9357,16 +9278,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.6.3"
|
||||
@@ -10074,7 +9985,7 @@ dependencies = [
|
||||
"libc",
|
||||
"mio",
|
||||
"pin-project-lite",
|
||||
"socket2 0.6.3",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
@@ -10754,16 +10665,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "warp"
|
||||
version = "0.3.7"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c"
|
||||
checksum = "51d06d9202adc1f15d709c4f4a2069be5428aa912cc025d6f268ac441ab066b0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"headers 0.3.9",
|
||||
"http 0.2.12",
|
||||
"hyper 0.14.32",
|
||||
"headers",
|
||||
"http 1.4.0",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"hyper",
|
||||
"hyper-util",
|
||||
"log",
|
||||
"mime",
|
||||
"mime_guess",
|
||||
@@ -11056,10 +10969,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"thiserror 1.0.69",
|
||||
"time",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
"url",
|
||||
"warp",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -46,6 +46,7 @@ bitflags = "2.11"
|
||||
brotli = "8.0.2"
|
||||
bytemuck = "1"
|
||||
byteorder = "1.5"
|
||||
bytes = "1.0"
|
||||
cbc = "0.1.2"
|
||||
cfg-if = "1.0.4"
|
||||
chacha20poly1305 = "0.10"
|
||||
@@ -208,7 +209,7 @@ utf-8 = "0.7"
|
||||
uuid = { version = "1.23.0", features = ["v4", "v5"] }
|
||||
vello = "0.6"
|
||||
vello_cpu = "0.0.4"
|
||||
webdriver = "0.53.0"
|
||||
webdriver = { version = "0.53.0", default-features = false }
|
||||
webpki-roots = "1.0"
|
||||
webrender = { version = "0.68", features = ["capture"] }
|
||||
webrender_api = "0.68"
|
||||
|
||||
@@ -76,7 +76,7 @@ servo-url = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
tracing = { workspace = true, optional = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync"] }
|
||||
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "sync", "fs"] }
|
||||
tokio-rustls = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec", "io"] }
|
||||
|
||||
@@ -15,11 +15,13 @@ path = "lib.rs"
|
||||
|
||||
[dependencies]
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
cookie = { workspace = true }
|
||||
crossbeam-channel = { workspace = true }
|
||||
embedder_traits = { workspace = true }
|
||||
euclid = { workspace = true }
|
||||
http = "0.2"
|
||||
http = { workspace = true }
|
||||
http02 = { package = "http", version = "0.2" }
|
||||
image = { workspace = true }
|
||||
keyboard-types = { workspace = true }
|
||||
log = { workspace = true }
|
||||
@@ -32,5 +34,8 @@ servo-geometry = { workspace = true }
|
||||
servo-url = { workspace = true }
|
||||
stylo_traits = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
webdriver = { workspace = true }
|
||||
url = { workspace = true }
|
||||
warp = { version = "0.4.2", features = ["server"] }
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
mod actions;
|
||||
mod capabilities;
|
||||
mod script_argument_extraction;
|
||||
mod server;
|
||||
mod session;
|
||||
mod timeout;
|
||||
mod user_prompt;
|
||||
@@ -42,6 +43,7 @@ use serde::de::{Deserializer, MapAccess, Visitor};
|
||||
use serde::ser::Serializer;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use server::{Session, SessionTeardownKind, WebDriverHandler};
|
||||
use servo_base::generic_channel::{self, GenericReceiver, GenericSender, RoutedReceiver};
|
||||
use servo_base::id::{BrowsingContextId, WebViewId};
|
||||
use servo_config::prefs::{self, PrefValue, Preferences};
|
||||
@@ -71,7 +73,6 @@ use webdriver::response::{
|
||||
CloseWindowResponse, CookieResponse, CookiesResponse, ElementRectResponse, NewSessionResponse,
|
||||
NewWindowResponse, TimeoutsResponse, ValueResponse, WebDriverResponse, WindowRectResponse,
|
||||
};
|
||||
use webdriver::server::{self, Session, SessionTeardownKind, WebDriverHandler};
|
||||
|
||||
use crate::actions::{ELEMENT_CLICK_BUTTON, InputSourceState, PendingActions, PointerInputState};
|
||||
use crate::session::{PageLoadStrategy, WebDriverSession};
|
||||
|
||||
731
components/webdriver_server/server.rs
Normal file
731
components/webdriver_server/server.rs
Normal file
@@ -0,0 +1,731 @@
|
||||
/* This Source Code Form is subject to the terms of the Mozilla Public
|
||||
* License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
// The code in this file was initially forked from the server module of the webdriver crate.
|
||||
// https://github.com/mozilla-firefox/firefox/blob/63719d122f9214f37fd1d285a91897b8345b88b0/testing/webdriver/src/server.rs
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use std::net::{SocketAddr, TcpListener as StdTcpListener};
|
||||
use std::sync::mpsc::{Receiver, Sender, channel};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
use bytes::Bytes;
|
||||
use http::{Method, StatusCode};
|
||||
use log::{debug, error, trace, warn};
|
||||
use tokio::net::TcpListener;
|
||||
use url::{Host, Url};
|
||||
use warp::{Buf, Filter, Rejection};
|
||||
use webdriver::command::{WebDriverCommand, WebDriverMessage};
|
||||
use webdriver::error::{ErrorStatus, WebDriverError, WebDriverResult};
|
||||
use webdriver::httpapi::{
|
||||
Route, VoidWebDriverExtensionRoute, WebDriverExtensionRoute, standard_routes,
|
||||
};
|
||||
use webdriver::response::{CloseWindowResponse, WebDriverResponse};
|
||||
|
||||
use crate::Parameters;
|
||||
|
||||
// Silence warning about Quit being unused for now.
|
||||
#[allow(dead_code)]
|
||||
enum DispatchMessage<U: WebDriverExtensionRoute> {
|
||||
HandleWebDriver(
|
||||
WebDriverMessage<U>,
|
||||
Sender<WebDriverResult<WebDriverResponse>>,
|
||||
),
|
||||
Quit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
/// Representation of whether we managed to successfully send a DeleteSession message
|
||||
/// and read the response during session teardown.
|
||||
pub enum SessionTeardownKind {
|
||||
/// A DeleteSession message has been sent and the response handled.
|
||||
Deleted,
|
||||
/// No DeleteSession message has been sent, or the response was not received.
|
||||
NotDeleted,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct Session {
|
||||
pub id: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
fn new(id: String) -> Session {
|
||||
Session { id }
|
||||
}
|
||||
}
|
||||
|
||||
pub trait WebDriverHandler<U: WebDriverExtensionRoute = VoidWebDriverExtensionRoute>: Send {
|
||||
fn handle_command(
|
||||
&mut self,
|
||||
session: &Option<Session>,
|
||||
msg: WebDriverMessage<U>,
|
||||
) -> WebDriverResult<WebDriverResponse>;
|
||||
fn teardown_session(&mut self, kind: SessionTeardownKind);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Dispatcher<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> {
|
||||
handler: T,
|
||||
session: Option<Session>,
|
||||
extension_type: PhantomData<U>,
|
||||
}
|
||||
|
||||
impl<T: WebDriverHandler<U>, U: WebDriverExtensionRoute> Dispatcher<T, U> {
|
||||
fn new(handler: T) -> Dispatcher<T, U> {
|
||||
Dispatcher {
|
||||
handler,
|
||||
session: None,
|
||||
extension_type: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, msg_chan: &Receiver<DispatchMessage<U>>) {
|
||||
loop {
|
||||
match msg_chan.recv() {
|
||||
Ok(DispatchMessage::HandleWebDriver(msg, resp_chan)) => {
|
||||
let resp = match self.check_session(&msg) {
|
||||
Ok(_) => self.handler.handle_command(&self.session, msg),
|
||||
Err(e) => Err(e),
|
||||
};
|
||||
|
||||
match resp {
|
||||
Ok(WebDriverResponse::NewSession(ref new_session)) => {
|
||||
self.session = Some(Session::new(new_session.session_id.clone()));
|
||||
},
|
||||
Ok(WebDriverResponse::CloseWindow(CloseWindowResponse(ref handles))) => {
|
||||
if handles.is_empty() {
|
||||
debug!("Last window was closed, deleting session");
|
||||
// The teardown_session implementation is responsible for actually
|
||||
// sending the DeleteSession message in this case
|
||||
self.teardown_session(SessionTeardownKind::NotDeleted);
|
||||
}
|
||||
},
|
||||
Ok(WebDriverResponse::DeleteSession) => {
|
||||
self.teardown_session(SessionTeardownKind::Deleted);
|
||||
},
|
||||
Err(ref x) if x.delete_session => {
|
||||
// This includes the case where we failed during session creation
|
||||
self.teardown_session(SessionTeardownKind::NotDeleted)
|
||||
},
|
||||
_ => {},
|
||||
}
|
||||
|
||||
if resp_chan.send(resp).is_err() {
|
||||
error!("Sending response to the main thread failed");
|
||||
};
|
||||
},
|
||||
Ok(DispatchMessage::Quit) => break,
|
||||
Err(e) => panic!("Error receiving message in handler: {:?}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn teardown_session(&mut self, kind: SessionTeardownKind) {
|
||||
debug!("Teardown session");
|
||||
let final_kind = match kind {
|
||||
SessionTeardownKind::NotDeleted if self.session.is_some() => {
|
||||
let delete_session = WebDriverMessage {
|
||||
session_id: Some(
|
||||
self.session
|
||||
.as_ref()
|
||||
.expect("Failed to get session")
|
||||
.id
|
||||
.clone(),
|
||||
),
|
||||
command: WebDriverCommand::DeleteSession,
|
||||
};
|
||||
match self.handler.handle_command(&self.session, delete_session) {
|
||||
Ok(_) => SessionTeardownKind::Deleted,
|
||||
Err(_) => SessionTeardownKind::NotDeleted,
|
||||
}
|
||||
},
|
||||
_ => kind,
|
||||
};
|
||||
self.handler.teardown_session(final_kind);
|
||||
self.session = None;
|
||||
}
|
||||
|
||||
fn check_session(&self, msg: &WebDriverMessage<U>) -> WebDriverResult<()> {
|
||||
match msg.session_id {
|
||||
Some(ref msg_session_id) => match self.session {
|
||||
Some(ref existing_session) => {
|
||||
if existing_session.id != *msg_session_id {
|
||||
Err(WebDriverError::new(
|
||||
ErrorStatus::InvalidSessionId,
|
||||
format!("Got unexpected session id {}", msg_session_id),
|
||||
))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
},
|
||||
None => Ok(()),
|
||||
},
|
||||
None => {
|
||||
match self.session {
|
||||
Some(_) => {
|
||||
match msg.command {
|
||||
WebDriverCommand::Status => Ok(()),
|
||||
WebDriverCommand::NewSession(_) => Err(WebDriverError::new(
|
||||
ErrorStatus::SessionNotCreated,
|
||||
"Session is already started",
|
||||
)),
|
||||
_ => {
|
||||
// This should be impossible
|
||||
error!("Got a message with no session id");
|
||||
Err(WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
"Got a command with no session?!",
|
||||
))
|
||||
},
|
||||
}
|
||||
},
|
||||
None => match msg.command {
|
||||
WebDriverCommand::NewSession(_) => Ok(()),
|
||||
WebDriverCommand::Status => Ok(()),
|
||||
_ => Err(WebDriverError::new(
|
||||
ErrorStatus::InvalidSessionId,
|
||||
"Tried to run a command before creating a session",
|
||||
)),
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Listener {
|
||||
guard: Option<thread::JoinHandle<()>>,
|
||||
pub socket: SocketAddr,
|
||||
}
|
||||
|
||||
impl Drop for Listener {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.guard.take().map(|j| j.join());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start<T, U>(
|
||||
mut address: SocketAddr,
|
||||
allow_hosts: Vec<Host>,
|
||||
allow_origins: Vec<Url>,
|
||||
handler: T,
|
||||
extension_routes: Vec<(Method, &'static str, U)>,
|
||||
) -> ::std::io::Result<Listener>
|
||||
where
|
||||
T: 'static + WebDriverHandler<U>,
|
||||
U: 'static + WebDriverExtensionRoute + Send + Sync,
|
||||
{
|
||||
let listener = StdTcpListener::bind(address)?;
|
||||
listener.set_nonblocking(true)?;
|
||||
let addr = listener.local_addr()?;
|
||||
if address.port() == 0 {
|
||||
// If we passed in 0 as the port number the OS will assign an unused port;
|
||||
// we want to update the address to the actual used port
|
||||
address.set_port(addr.port())
|
||||
}
|
||||
let (msg_send, msg_recv) = channel();
|
||||
|
||||
let builder = thread::Builder::new().name("webdriver server".to_string());
|
||||
let handle = builder.spawn(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.build()
|
||||
.unwrap();
|
||||
let listener = rt.block_on(async { TcpListener::from_std(listener).unwrap() });
|
||||
let wroutes = build_warp_routes(
|
||||
address,
|
||||
allow_hosts,
|
||||
allow_origins,
|
||||
&extension_routes,
|
||||
msg_send.clone(),
|
||||
);
|
||||
let fut = warp::serve(wroutes).incoming(listener).run();
|
||||
rt.block_on(fut);
|
||||
})?;
|
||||
|
||||
let builder = thread::Builder::new().name("webdriver dispatcher".to_string());
|
||||
builder.spawn(move || {
|
||||
let mut dispatcher = Dispatcher::new(handler);
|
||||
dispatcher.run(&msg_recv);
|
||||
})?;
|
||||
|
||||
Ok(Listener {
|
||||
guard: Some(handle),
|
||||
socket: addr,
|
||||
})
|
||||
}
|
||||
|
||||
fn build_warp_routes<U: 'static + WebDriverExtensionRoute + Send + Sync>(
|
||||
address: SocketAddr,
|
||||
allow_hosts: Vec<Host>,
|
||||
allow_origins: Vec<Url>,
|
||||
ext_routes: &[(Method, &'static str, U)],
|
||||
chan: Sender<DispatchMessage<U>>,
|
||||
) -> impl Filter<Extract = (impl warp::Reply,), Error = Rejection> + Clone + 'static {
|
||||
let chan = Arc::new(Mutex::new(chan));
|
||||
let mut std_routes = standard_routes::<U>();
|
||||
|
||||
let (method, path, res) = std_routes.pop().unwrap();
|
||||
trace!("Build standard route for {path}");
|
||||
let mut wroutes = build_route(
|
||||
address,
|
||||
allow_hosts.clone(),
|
||||
allow_origins.clone(),
|
||||
convert_method(method),
|
||||
path,
|
||||
res,
|
||||
chan.clone(),
|
||||
);
|
||||
|
||||
for (method, path, res) in std_routes {
|
||||
trace!("Build standard route for {path}");
|
||||
wroutes = wroutes
|
||||
.or(build_route(
|
||||
address,
|
||||
allow_hosts.clone(),
|
||||
allow_origins.clone(),
|
||||
convert_method(method),
|
||||
path,
|
||||
res.clone(),
|
||||
chan.clone(),
|
||||
))
|
||||
.unify()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
for (method, path, res) in ext_routes {
|
||||
trace!("Build vendor route for {path}");
|
||||
wroutes = wroutes
|
||||
.or(build_route(
|
||||
address,
|
||||
allow_hosts.clone(),
|
||||
allow_origins.clone(),
|
||||
method.clone(),
|
||||
path,
|
||||
Route::Extension(res.clone()),
|
||||
chan.clone(),
|
||||
))
|
||||
.unify()
|
||||
.boxed()
|
||||
}
|
||||
|
||||
wroutes
|
||||
}
|
||||
|
||||
fn is_host_allowed(server_address: &SocketAddr, allow_hosts: &[Host], host_header: &str) -> bool {
|
||||
// Validate that the Host header value has a hostname in allow_hosts and
|
||||
// the port matches the server configuration
|
||||
let header_host_url = match Url::parse(&format!("http://{}", &host_header)) {
|
||||
Ok(x) => x,
|
||||
Err(_) => {
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let host = match header_host_url.host() {
|
||||
Some(host) => host.to_owned(),
|
||||
None => {
|
||||
// This shouldn't be possible since http URL always have a
|
||||
// host, but conservatively return false here, which will cause
|
||||
// an error response
|
||||
return false;
|
||||
},
|
||||
};
|
||||
let port = match header_host_url.port_or_known_default() {
|
||||
Some(port) => port,
|
||||
None => {
|
||||
// This shouldn't be possible since http URL always have a
|
||||
// default port, but conservatively return false here, which will cause
|
||||
// an error response
|
||||
return false;
|
||||
},
|
||||
};
|
||||
|
||||
let host_matches = match host {
|
||||
Host::Domain(_) => allow_hosts.contains(&host),
|
||||
Host::Ipv4(_) | Host::Ipv6(_) => true,
|
||||
};
|
||||
let port_matches = server_address.port() == port;
|
||||
host_matches && port_matches
|
||||
}
|
||||
|
||||
fn is_origin_allowed(allow_origins: &[Url], origin_url: Url) -> bool {
|
||||
// Validate that the Origin header value is in allow_origins
|
||||
allow_origins.contains(&origin_url)
|
||||
}
|
||||
|
||||
fn build_route<U: 'static + WebDriverExtensionRoute + Send + Sync>(
|
||||
server_address: SocketAddr,
|
||||
allow_hosts: Vec<Host>,
|
||||
allow_origins: Vec<Url>,
|
||||
method: Method,
|
||||
path: &'static str,
|
||||
route: Route<U>,
|
||||
chan: Arc<Mutex<Sender<DispatchMessage<U>>>>,
|
||||
) -> warp::filters::BoxedFilter<(impl warp::Reply,)> {
|
||||
// Create an empty filter based on the provided method and append an empty hashmap to it. The
|
||||
// hashmap will be used to store path parameters.
|
||||
let mut subroute = match method {
|
||||
Method::GET => warp::get().boxed(),
|
||||
Method::POST => warp::post().boxed(),
|
||||
Method::DELETE => warp::delete().boxed(),
|
||||
Method::OPTIONS => warp::options().boxed(),
|
||||
Method::PUT => warp::put().boxed(),
|
||||
_ => panic!("Unsupported method"),
|
||||
}
|
||||
.or(warp::head())
|
||||
.unify()
|
||||
.map(Parameters::new)
|
||||
.boxed();
|
||||
|
||||
// For each part of the path, if it's a normal part, just append it to the current filter,
|
||||
// otherwise if it's a parameter (a named enclosed in { }), we take that parameter and insert
|
||||
// it into the hashmap created earlier.
|
||||
for part in path.split('/') {
|
||||
if part.is_empty() {
|
||||
continue;
|
||||
} else if part.starts_with('{') {
|
||||
assert!(part.ends_with('}'));
|
||||
|
||||
subroute = subroute
|
||||
.and(warp::path::param())
|
||||
.map(move |mut params: Parameters, param: String| {
|
||||
let name = &part[1..part.len() - 1];
|
||||
params.insert(name.to_string(), param);
|
||||
params
|
||||
})
|
||||
.boxed();
|
||||
} else {
|
||||
subroute = subroute.and(warp::path(part)).boxed();
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, tell warp that the path is complete
|
||||
subroute
|
||||
.and(warp::path::end())
|
||||
.and(warp::path::full())
|
||||
.and(warp::method())
|
||||
.and(warp::header::optional::<String>("origin"))
|
||||
.and(warp::header::optional::<String>("host"))
|
||||
.and(warp::header::optional::<String>("content-type"))
|
||||
.and(warp::body::bytes())
|
||||
.map(
|
||||
move |params,
|
||||
full_path: warp::path::FullPath,
|
||||
method,
|
||||
origin_header: Option<String>,
|
||||
host_header: Option<String>,
|
||||
content_type_header: Option<String>,
|
||||
body: Bytes| {
|
||||
if method == Method::HEAD {
|
||||
return warp::reply::with_status("".into(), StatusCode::OK);
|
||||
}
|
||||
if let Some(host) = host_header {
|
||||
if !is_host_allowed(&server_address, &allow_hosts, &host) {
|
||||
warn!(
|
||||
"Rejected request with Host header {}, allowed values are [{}]",
|
||||
host,
|
||||
allow_hosts
|
||||
.iter()
|
||||
.map(|x| format!("{}:{}", x, server_address.port()))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
let err = WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
format!("Invalid Host header {}", host),
|
||||
);
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&err).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
};
|
||||
} else {
|
||||
warn!("Rejected request with missing Host header");
|
||||
let err = WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
"Missing Host header".to_string(),
|
||||
);
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&err).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
}
|
||||
if let Some(origin) = origin_header {
|
||||
let make_err = || {
|
||||
warn!(
|
||||
"Rejected request with Origin header {}, allowed values are [{}]",
|
||||
origin,
|
||||
allow_origins
|
||||
.iter()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",")
|
||||
);
|
||||
WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
format!("Invalid Origin header {}", origin),
|
||||
)
|
||||
};
|
||||
let origin_url = match Url::parse(&origin) {
|
||||
Ok(url) => url,
|
||||
Err(_) => {
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&make_err()).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
},
|
||||
};
|
||||
if !is_origin_allowed(&allow_origins, origin_url) {
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&make_err()).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
}
|
||||
}
|
||||
if method == Method::POST {
|
||||
// Disallow CORS-safelisted request headers
|
||||
// c.f. https://fetch.spec.whatwg.org/#cors-safelisted-request-header
|
||||
let content_type = content_type_header
|
||||
.as_ref()
|
||||
.map(|x| x.find(';').and_then(|idx| x.get(0..idx)).unwrap_or(x))
|
||||
.map(|x| x.trim())
|
||||
.map(|x| x.to_lowercase());
|
||||
match content_type.as_ref().map(|x| x.as_ref()) {
|
||||
Some("application/x-www-form-urlencoded") |
|
||||
Some("multipart/form-data") |
|
||||
Some("text/plain") => {
|
||||
warn!(
|
||||
"Rejected POST request with disallowed content type {}",
|
||||
content_type.unwrap_or_else(|| "".into())
|
||||
);
|
||||
let err = WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
"Invalid Content-Type",
|
||||
);
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&err).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
},
|
||||
Some(_) | None => {},
|
||||
}
|
||||
}
|
||||
let body = String::from_utf8(body.chunk().to_vec());
|
||||
if body.is_err() {
|
||||
let err = WebDriverError::new(
|
||||
ErrorStatus::UnknownError,
|
||||
"Request body wasn't valid UTF-8",
|
||||
);
|
||||
return warp::reply::with_status(
|
||||
serde_json::to_string(&err).unwrap(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
);
|
||||
}
|
||||
let body = body.unwrap();
|
||||
|
||||
debug!("-> {} {} {}", method, full_path.as_str(), body);
|
||||
let msg_result = WebDriverMessage::from_http(
|
||||
route.clone(),
|
||||
¶ms,
|
||||
&body,
|
||||
method == Method::POST,
|
||||
);
|
||||
|
||||
let (status, resp_body) = match msg_result {
|
||||
Ok(message) => {
|
||||
let (send_res, recv_res) = channel();
|
||||
match chan.lock() {
|
||||
Ok(ref c) => {
|
||||
let res =
|
||||
c.send(DispatchMessage::HandleWebDriver(message, send_res));
|
||||
match res {
|
||||
Ok(x) => x,
|
||||
Err(e) => panic!("Error: {:?}", e),
|
||||
}
|
||||
},
|
||||
Err(e) => panic!("Error reading response: {:?}", e),
|
||||
}
|
||||
|
||||
match recv_res.recv() {
|
||||
Ok(data) => match data {
|
||||
Ok(response) => {
|
||||
(StatusCode::OK, serde_json::to_string(&response).unwrap())
|
||||
},
|
||||
Err(e) => (
|
||||
StatusCode::from_u16(e.http_status().as_u16()).unwrap(),
|
||||
serde_json::to_string(&e).unwrap(),
|
||||
),
|
||||
},
|
||||
Err(e) => panic!("Error reading response: {:?}", e),
|
||||
}
|
||||
},
|
||||
Err(e) => (
|
||||
convert_status(e.http_status()),
|
||||
serde_json::to_string(&e).unwrap(),
|
||||
),
|
||||
};
|
||||
|
||||
debug!("<- {} {}", status, resp_body);
|
||||
warp::reply::with_status(resp_body, status)
|
||||
},
|
||||
)
|
||||
.with(warp::reply::with::header(
|
||||
http::header::CONTENT_TYPE,
|
||||
"application/json; charset=utf-8",
|
||||
))
|
||||
.with(warp::reply::with::header(
|
||||
http::header::CACHE_CONTROL,
|
||||
"no-cache",
|
||||
))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Convert from http 0.2 StatusCode to http 1.0 StatusCode
|
||||
fn convert_status(status: http02::StatusCode) -> StatusCode {
|
||||
StatusCode::from_u16(status.as_u16()).unwrap()
|
||||
}
|
||||
|
||||
/// Convert from http 0.2 Method to http 1.0 Method
|
||||
fn convert_method(method: http02::Method) -> Method {
|
||||
match method {
|
||||
http02::Method::OPTIONS => http::Method::OPTIONS,
|
||||
http02::Method::GET => http::Method::GET,
|
||||
http02::Method::POST => http::Method::POST,
|
||||
http02::Method::PUT => http::Method::PUT,
|
||||
http02::Method::DELETE => http::Method::DELETE,
|
||||
http02::Method::HEAD => http::Method::HEAD,
|
||||
http02::Method::TRACE => http::Method::TRACE,
|
||||
http02::Method::CONNECT => http::Method::CONNECT,
|
||||
http02::Method::PATCH => http::Method::PATCH,
|
||||
_ => http::Method::from_bytes(method.as_str().as_bytes()).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_host_allowed() {
|
||||
let addr_80 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80);
|
||||
let addr_8000 = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 8000);
|
||||
let addr_v6_80 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 80);
|
||||
let addr_v6_8000 = SocketAddr::new(IpAddr::from_str("::1").unwrap(), 8000);
|
||||
|
||||
// We match the host ip address to the server, so we can only use hosts that actually resolve
|
||||
let localhost_host = Host::Domain("localhost".to_string());
|
||||
let test_host = Host::Domain("example.test".to_string());
|
||||
let subdomain_localhost_host = Host::Domain("subdomain.localhost".to_string());
|
||||
|
||||
assert!(is_host_allowed(
|
||||
&addr_80,
|
||||
&[localhost_host.clone()],
|
||||
"localhost:80"
|
||||
));
|
||||
assert!(is_host_allowed(
|
||||
&addr_80,
|
||||
&[test_host.clone()],
|
||||
"example.test:80"
|
||||
));
|
||||
assert!(is_host_allowed(
|
||||
&addr_80,
|
||||
&[test_host.clone(), localhost_host.clone()],
|
||||
"example.test"
|
||||
));
|
||||
assert!(is_host_allowed(
|
||||
&addr_80,
|
||||
&[subdomain_localhost_host.clone()],
|
||||
"subdomain.localhost"
|
||||
));
|
||||
|
||||
// ip address cases
|
||||
assert!(is_host_allowed(&addr_80, &[], "127.0.0.1:80"));
|
||||
assert!(is_host_allowed(&addr_v6_80, &[], "127.0.0.1"));
|
||||
assert!(is_host_allowed(&addr_80, &[], "[::1]"));
|
||||
assert!(is_host_allowed(&addr_8000, &[], "127.0.0.1:8000"));
|
||||
assert!(is_host_allowed(
|
||||
&addr_80,
|
||||
&[subdomain_localhost_host.clone()],
|
||||
"[::1]"
|
||||
));
|
||||
assert!(is_host_allowed(
|
||||
&addr_v6_8000,
|
||||
&[subdomain_localhost_host.clone()],
|
||||
"[::1]:8000"
|
||||
));
|
||||
|
||||
// Mismatch cases
|
||||
|
||||
assert!(!is_host_allowed(&addr_80, &[test_host], "localhost"));
|
||||
|
||||
assert!(!is_host_allowed(&addr_80, &[], "localhost:80"));
|
||||
|
||||
// Port mismatch cases
|
||||
|
||||
assert!(!is_host_allowed(
|
||||
&addr_80,
|
||||
&[localhost_host.clone()],
|
||||
"localhost:8000"
|
||||
));
|
||||
assert!(!is_host_allowed(
|
||||
&addr_8000,
|
||||
&[localhost_host.clone()],
|
||||
"localhost"
|
||||
));
|
||||
assert!(!is_host_allowed(
|
||||
&addr_v6_8000,
|
||||
&[localhost_host.clone()],
|
||||
"[::1]"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_origin_allowed() {
|
||||
assert!(is_origin_allowed(
|
||||
&[Url::parse("http://localhost").unwrap()],
|
||||
Url::parse("http://localhost").unwrap()
|
||||
));
|
||||
assert!(is_origin_allowed(
|
||||
&[Url::parse("http://localhost").unwrap()],
|
||||
Url::parse("http://localhost:80").unwrap()
|
||||
));
|
||||
assert!(is_origin_allowed(
|
||||
&[
|
||||
Url::parse("https://test.example").unwrap(),
|
||||
Url::parse("http://localhost").unwrap()
|
||||
],
|
||||
Url::parse("http://localhost").unwrap()
|
||||
));
|
||||
assert!(is_origin_allowed(
|
||||
&[
|
||||
Url::parse("https://test.example").unwrap(),
|
||||
Url::parse("http://localhost").unwrap()
|
||||
],
|
||||
Url::parse("https://test.example:443").unwrap()
|
||||
));
|
||||
// Mismatch cases
|
||||
assert!(!is_origin_allowed(
|
||||
&[],
|
||||
Url::parse("http://localhost").unwrap()
|
||||
));
|
||||
assert!(!is_origin_allowed(
|
||||
&[Url::parse("http://localhost").unwrap()],
|
||||
Url::parse("http://localhost:8000").unwrap()
|
||||
));
|
||||
assert!(!is_origin_allowed(
|
||||
&[Url::parse("https://localhost").unwrap()],
|
||||
Url::parse("http://localhost").unwrap()
|
||||
));
|
||||
assert!(!is_origin_allowed(
|
||||
&[Url::parse("https://example.test").unwrap()],
|
||||
Url::parse("http://subdomain.example.test").unwrap()
|
||||
));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user