keytrans: Implement improved logic for monitor_and_search

This commit is contained in:
moiseev-signal
2026-03-12 11:00:36 -07:00
committed by GitHub
parent f7c4aceebd
commit 6c2bf65989
11 changed files with 2128 additions and 666 deletions

1
Cargo.lock generated
View File

@@ -2606,6 +2606,7 @@ dependencies = [
"curve25519-dalek",
"displaydoc",
"ed25519-dalek",
"hex",
"hmac",
"itertools 0.14.0",
"proptest",

View File

@@ -5,3 +5,5 @@ v0.88.2
- Expose `useH2ForAuthChat` remote configuration key to use HTTP/2 for AuthenticatedChatConnection's non-fronted connections.
- The `disableNagleAlgorithm` remote config flag has been removed, as the experiment has been deployed successfully.
- keytrans: Improve monitor_and_search logic to handle a wider set of scenarios (keep monitoring unchanged mappings, while falling back to search for the rest).

View File

@@ -17,6 +17,7 @@ workspace = true
curve25519-dalek = { workspace = true }
displaydoc = { workspace = true }
ed25519-dalek = { workspace = true }
hex = { workspace = true }
hmac = { workspace = true }
itertools = { workspace = true }
prost = { workspace = true }

View File

@@ -9,7 +9,9 @@ fn main() {
"src/proto/chat.proto",
];
let mut prost_build = prost_build::Config::new();
prost_build.protoc_arg("--experimental_allow_proto3_optional");
prost_build
.protoc_arg("--experimental_allow_proto3_optional")
.skip_debug(["TreeHead", "Signature"]);
prost_build
.compile_protos(&protos, &["src/proto"])
.expect("Protobufs in src are valid");

View File

@@ -16,6 +16,7 @@ mod verify;
mod vrf;
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::time::SystemTime;
pub use ed25519_dalek::VerifyingKey;
@@ -87,11 +88,39 @@ impl DeploymentMode {
}
}
impl Debug for Signature {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Signature")
.field("auditor_public_key", &hex::encode(&self.auditor_public_key))
.field("signature", &hex::encode(&self.signature))
.finish()
}
}
impl Debug for TreeHead {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TreeHead")
.field("tree_size", &self.tree_size)
.field("timestamp", &self.timestamp)
.field("signatures", &self.signatures.iter().collect::<Vec<_>>())
.finish()
}
}
pub type TreeRoot = [u8; 32];
#[derive(Debug, Clone, PartialEq)]
#[derive(Clone, PartialEq)]
pub struct LastTreeHead(pub TreeHead, pub TreeRoot);
impl Debug for LastTreeHead {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("LastTreeHead")
.field(&self.0)
.field(&hex::encode(self.1))
.finish()
}
}
impl StoredTreeHead {
pub fn into_last_tree_head(self) -> Option<LastTreeHead> {
let StoredTreeHead {
@@ -271,7 +300,7 @@ impl KeyTransparency {
/// MonitoringData is the structure retained for each key in the KT server being
/// monitored.
#[derive(Debug, Eq, PartialEq, Clone)]
#[derive(Eq, PartialEq, Clone)]
pub struct MonitoringData {
/// The VRF output on the search key.
pub index: [u8; 32],
@@ -281,6 +310,37 @@ pub struct MonitoringData {
pub ptrs: HashMap<u64, u32>,
/// Whether this client owns the key.
pub owned: bool,
/// Search key
pub search_key: Vec<u8>,
}
impl Debug for MonitoringData {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
let Self {
index,
pos,
ptrs,
owned,
search_key,
} = self;
let redact_bytes = |bytes: &[u8]| {
let redacted = if let Some(last) = bytes.last_chunk::<3>() {
hex::encode(last)
} else {
"".to_owned()
};
["[REDACTED]", &redacted].join(" ...")
};
f.debug_struct("MonitoringData")
.field("index", &redact_bytes(index))
.field("pos", &pos)
.field("ptrs", &ptrs)
.field("owned", &owned)
.field("search_key", &redact_bytes(search_key))
.finish()
}
}
impl MonitoringData {
@@ -329,11 +389,19 @@ impl MonitoringData {
impl From<StoredMonitoringData> for MonitoringData {
fn from(value: StoredMonitoringData) -> Self {
let StoredMonitoringData {
index,
pos,
ptrs,
owned,
search_key,
} = value;
Self {
index: value.index.try_into().expect("must me the right size"),
pos: value.pos,
ptrs: value.ptrs,
owned: value.owned,
index: index.try_into().expect("must be the right size"),
pos,
ptrs,
owned,
search_key,
}
}
}

View File

@@ -806,6 +806,7 @@ impl MonitoringDataWrapper {
pos: zero_pos,
ptrs: HashMap::from([(ver_pos, version)]),
owned,
search_key: vec![],
});
}
}
@@ -1123,6 +1124,7 @@ mod test {
// See test_stored_account_data in rust/net/chat/src/api/keytrans.rs
ptrs: HashMap::from_iter([(16777215, 2)]),
owned: true,
search_key: vec![],
}));
// These values were obtained by running the integration test in
// rust/net/chat/src/api/keytrans.rs and extracting positions and versions
@@ -1202,6 +1204,7 @@ mod test {
pos: 10, // The search key is introduced here
ptrs: HashMap::from([(10, 1)]),
owned: true,
search_key: vec![],
}));
let steps = proof_steps([(11, 1), (15, 2)]);
@@ -1219,6 +1222,7 @@ mod test {
pos: 10, // The search key is introduced here
ptrs: HashMap::from([(10, 1)]),
owned: true,
search_key: vec![],
}));
// later position contains a smaller version
let steps = proof_steps([(11, 0)]);
@@ -1234,6 +1238,7 @@ mod test {
pos: 10, // The search key is introduced here
ptrs: HashMap::from([(10, 1)]),
owned: true,
search_key: vec![],
}));
let steps = HashMap::from_iter([
@@ -1253,6 +1258,7 @@ mod test {
pos: 10, // The search key is introduced here
ptrs: HashMap::from([(10, 1), (11, 2)]),
owned: true,
search_key: vec![],
}));
let steps = proof_steps([(11, 3)]);
let result = wrapper.update(16, &steps);

View File

@@ -3,8 +3,12 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
mod maybe_partial;
mod monitor_and_search;
mod verify_ext;
use std::borrow::Cow;
use std::collections::{BTreeSet, HashMap};
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::time::SystemTime;
@@ -15,10 +19,13 @@ use libsignal_keytrans::{
AccountData, ChatDistinguishedResponse, ChatMonitorResponse, ChatSearchResponse,
CondensedTreeSearchResponse, FullSearchResponse, FullTreeHead, KeyTransparency, LastTreeHead,
LocalStateUpdate, MonitorContext, MonitorKey, MonitorProof, MonitorRequest, MonitorResponse,
MonitoringData, SearchContext, SearchStateUpdate, SlimSearchRequest, VerifiedSearchResult,
SearchContext, SearchStateUpdate, SlimSearchRequest,
};
use libsignal_net::env::KeyTransConfig;
use libsignal_protocol::PublicKey;
pub use maybe_partial::{AccountDataField, MaybePartial};
pub use monitor_and_search::{MonitorMode, check as monitor_and_search};
use verify_ext::KeyTransparencyVerifyExt as _;
use super::RequestError;
@@ -74,18 +81,7 @@ pub(crate) struct TypedSearchResponse {
}
impl TypedSearchResponse {
pub(crate) fn from_untyped(
require_e164: bool,
require_username_hash: bool,
response: ChatSearchResponse,
) -> Result<Self, Error> {
if require_e164 != response.e164.is_some()
|| require_username_hash != response.username_hash.is_some()
{
return Err(Error::InvalidResponse(
"request/response optionality mismatch".to_string(),
));
}
pub(crate) fn from_untyped(response: ChatSearchResponse) -> Result<Self, Error> {
let ChatSearchResponse {
tree_head,
aci,
@@ -146,104 +142,6 @@ impl TypedMonitorResponse {
}
}
/// A tag identifying an optional field in [`AccountData`]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, displaydoc::Display)]
pub enum AccountDataField {
/// E.164
E164,
/// Username hash
UsernameHash,
}
/// This struct adds to its type parameter a (potentially empty) list of
/// account fields (see [`AccountDataField`]) that can no longer be verified
/// by the server.
///
/// Basically it is a non-generic version of a `Writer` monad with [`BTreeSet`] used
/// to accumulate missing field entries in some order while avoiding duplicates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaybePartial<T> {
pub inner: T,
pub missing_fields: BTreeSet<AccountDataField>,
}
impl<T> From<T> for MaybePartial<T> {
fn from(value: T) -> Self {
Self {
inner: value,
missing_fields: Default::default(),
}
}
}
impl<T> MaybePartial<T> {
fn new_complete(inner: T) -> Self {
Self {
inner,
missing_fields: Default::default(),
}
}
fn new(inner: T, missing_fields: impl IntoIterator<Item = AccountDataField>) -> Self {
Self {
inner,
missing_fields: BTreeSet::from_iter(missing_fields),
}
}
pub fn map<U>(self, f: impl FnOnce(T) -> U) -> MaybePartial<U> {
MaybePartial {
inner: f(self.inner),
missing_fields: self.missing_fields,
}
}
pub fn and_then<U>(self, f: impl FnOnce(T) -> MaybePartial<U>) -> MaybePartial<U> {
let MaybePartial {
inner,
mut missing_fields,
} = self;
let MaybePartial {
inner: final_inner,
missing_fields: other_missing,
} = f(inner);
missing_fields.extend(other_missing);
MaybePartial {
inner: final_inner,
missing_fields,
}
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn into_result(self) -> Result<T, BTreeSet<AccountDataField>> {
let Self {
inner,
missing_fields,
} = self;
if missing_fields.is_empty() {
Ok(inner)
} else {
Err(missing_fields)
}
}
}
impl<T, E> MaybePartial<std::result::Result<T, E>> {
fn transpose(self) -> std::result::Result<MaybePartial<T>, E> {
let MaybePartial {
inner,
missing_fields,
} = self;
Ok(MaybePartial {
inner: inner?,
missing_fields,
})
}
}
/// Representation of an object as "search key" aligned with conversion
/// performed by the chat server.
///
@@ -255,6 +153,12 @@ pub trait SearchKey {
fn as_search_key(&self) -> Vec<u8>;
}
impl<T: SearchKey> SearchKey for &T {
fn as_search_key(&self) -> Vec<u8> {
(*self).as_search_key()
}
}
impl SearchKey for Aci {
fn as_search_key(&self) -> Vec<u8> {
[SEARCH_KEY_PREFIX_ACI, self.service_id_binary().as_slice()].concat()
@@ -349,110 +253,6 @@ pub trait UnauthenticatedChatApi {
) -> impl Future<Output = Result<AccountData, RequestError<Error>>> + Send;
}
#[derive(Eq, Debug, PartialEq, Clone, Copy)]
pub enum MonitorMode {
MonitorSelf,
MonitorOther,
}
#[derive(Debug, Clone, PartialEq)]
struct SearchVersions {
aci: Option<u32>,
e164: Option<u32>,
username_hash: Option<u32>,
}
impl SearchVersions {
fn from_account_data(account_data: &AccountData) -> Self {
let AccountData {
aci,
e164,
username_hash,
..
} = account_data;
Self {
aci: Some(aci.greatest_version()),
e164: e164.as_ref().map(|x| x.greatest_version()),
username_hash: username_hash.as_ref().map(|x| x.greatest_version()),
}
}
fn subtract(&self, other: &Self) -> Self {
fn opt_sub<T>(lhs: &T, rhs: &T, f: impl Fn(&T) -> Option<u32>) -> Option<u32> {
Some(f(lhs)? - f(rhs)?)
}
Self {
aci: opt_sub(self, other, |x| x.aci),
e164: opt_sub(self, other, |x| x.e164),
username_hash: opt_sub(self, other, |x| x.username_hash),
}
}
fn maximum_version(&self) -> Option<u32> {
[self.aci, self.e164, self.username_hash]
.into_iter()
.flatten()
.max()
}
}
pub async fn monitor_and_search(
kt: &impl UnauthenticatedChatApi,
aci: &Aci,
aci_identity_key: &PublicKey,
e164: Option<(E164, Vec<u8>)>,
username_hash: Option<UsernameHash<'_>>,
stored_account_data: AccountData,
distinguished_tree_head: &LastTreeHead,
mode: MonitorMode,
) -> Result<MaybePartial<AccountData>, RequestError<Error>> {
let updated_account_data = kt
.monitor(
aci,
e164.as_ref().map(|(e164, _)| *e164),
username_hash.clone(),
stored_account_data.clone(),
distinguished_tree_head,
)
.await?;
let stored_versions = SearchVersions::from_account_data(&stored_account_data);
let updated_versions = SearchVersions::from_account_data(&updated_account_data);
let version_delta = updated_versions.subtract(&stored_versions);
// Call to `monitor` guarantees that the optionality of E.164 and username hash data
// will match between `stored_account_data` and `updated_account_data`. Meaning, they will
// either both be Some() or both None.
let any_version_changed = version_delta.maximum_version().filter(|n| n > &0).is_some();
let final_account_data = match (mode, any_version_changed) {
(MonitorMode::MonitorSelf, true) => {
return Err(RequestError::Other(
libsignal_keytrans::Error::VerificationFailed(
"version change detected while self-monitoring".to_string(),
)
.into(),
));
}
(MonitorMode::MonitorOther, true) => {
kt.search(
aci,
aci_identity_key,
e164,
username_hash,
Some(stored_account_data),
distinguished_tree_head,
)
.await?
}
(MonitorMode::MonitorSelf, false) | (MonitorMode::MonitorOther, false) => {
updated_account_data.into()
}
};
Ok(final_account_data)
}
impl<'a> KeyTransparencyClient<'a> {
pub fn new(chat: &'a (dyn LowLevelChatApi + Sync), kt_config: KeyTransConfig) -> Self {
Self {
@@ -492,24 +292,35 @@ impl UnauthenticatedChatApi for KeyTransparencyClient<'_> {
))
})
})
.and_then(|r| {
TypedSearchResponse::from_untyped(e164.is_some(), username_hash.is_some(), r)
.map_err(RequestError::Other)
})?;
.and_then(|r| TypedSearchResponse::from_untyped(r).map_err(RequestError::Other))?;
let now = SystemTime::now();
verify_chat_search_response(
&self.inner,
aci,
e164.map(|(e164, _)| e164),
username_hash,
stored_account_data,
chat_search_response,
Some(distinguished_tree_head),
now,
)
.map_err(RequestError::Other)
let e164_search_key = e164.as_ref().map(|(x, _)| x.as_search_key());
let username_hash_search_key = username_hash.as_ref().map(SearchKey::as_search_key);
let mut account_data = self
.inner
.verify_chat_search_response(
aci,
e164.map(|(e164, _)| e164),
username_hash,
stored_account_data,
chat_search_response,
Some(distinguished_tree_head),
now,
)
.map_err(RequestError::Other)?;
// Preserve search keys in account data
account_data.inner.aci.search_key = aci.as_search_key();
if let Some(stored) = account_data.inner.e164.as_mut() {
stored.search_key = e164_search_key.unwrap_or_default();
}
if let Some(stored) = account_data.inner.username_hash.as_mut() {
stored.search_key = username_hash_search_key.unwrap_or_default();
}
Ok(account_data)
}
async fn distinguished(
@@ -685,9 +496,15 @@ impl UnauthenticatedChatApi for KeyTransparencyClient<'_> {
} = verified;
let mut take_data = move |search_key: &[u8], err_message: &'static str| {
monitoring_data.remove(search_key).ok_or_else(|| {
RequestError::Other(Error::InvalidResponse(err_message.to_string()))
})
monitoring_data
.remove(search_key)
.map(|mut d| {
d.search_key = search_key.to_vec();
d
})
.ok_or_else(|| {
RequestError::Other(Error::InvalidResponse(err_message.to_string()))
})
};
AccountData {
@@ -713,186 +530,12 @@ impl UnauthenticatedChatApi for KeyTransparencyClient<'_> {
}
}
fn verify_single_search_response(
kt: &KeyTransparency,
search_key: Vec<u8>,
response: CondensedTreeSearchResponse,
monitoring_data: Option<MonitoringData>,
full_tree_head: &FullTreeHead,
last_tree_head: Option<&LastTreeHead>,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<VerifiedSearchResult, Error> {
let result = kt.verify_search(
SlimSearchRequest {
search_key,
version: None,
},
FullSearchResponse::new(response, full_tree_head),
SearchContext {
last_tree_head,
last_distinguished_tree_head,
data: monitoring_data,
},
true,
now,
)?;
Ok(result)
}
fn verify_chat_search_response(
kt: &KeyTransparency,
aci: &Aci,
e164: Option<E164>,
username_hash: Option<UsernameHash>,
stored_account_data: Option<AccountData>,
chat_search_response: TypedSearchResponse,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<MaybePartial<AccountData>, Error> {
let TypedSearchResponse {
full_tree_head,
aci_search_response,
e164_search_response,
username_hash_search_response,
} = chat_search_response;
let (
aci_monitoring_data,
e164_monitoring_data,
username_hash_monitoring_data,
stored_last_tree_head,
) = match stored_account_data {
None => (None, None, None, None),
Some(acc) => {
let AccountData {
aci,
e164,
username_hash,
last_tree_head,
} = acc;
(Some(aci), e164, username_hash, Some(last_tree_head))
}
};
let aci_result = verify_single_search_response(
kt,
aci.as_search_key(),
aci_search_response,
aci_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)?;
let e164_result = match_optional_fields(e164, e164_search_response, AccountDataField::E164)?
.map(|non_partial| {
non_partial
.map(|(e164, e164_search_response)| {
verify_single_search_response(
kt,
e164.as_search_key(),
e164_search_response,
e164_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)
})
.transpose()
})
.transpose()?;
let username_hash_result = match_optional_fields(
username_hash,
username_hash_search_response,
AccountDataField::UsernameHash,
)?
.map(|non_partial| {
non_partial
.map(|(username_hash, username_hash_response)| {
verify_single_search_response(
kt,
username_hash.as_search_key(),
username_hash_response,
username_hash_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)
})
.transpose()
})
.transpose()?;
let MaybePartial {
inner: (e164_result, username_hash_result),
missing_fields,
} = e164_result.and_then(|e164| username_hash_result.map(|hash| (e164, hash)));
if !aci_result.are_all_roots_equal([e164_result.as_ref(), username_hash_result.as_ref()]) {
return Err(Error::InvalidResponse("mismatching tree roots".to_string()));
}
// ACI response is guaranteed to be present, taking the last tree head from it.
let LocalStateUpdate {
tree_head,
tree_root,
monitoring_data: updated_aci_monitoring_data,
} = aci_result.state_update;
let updated_account_data = AccountData {
aci: updated_aci_monitoring_data
.ok_or_else(|| Error::InvalidResponse("ACI data is missing".to_string()))?,
e164: e164_result.and_then(|r| r.state_update.monitoring_data),
username_hash: username_hash_result.and_then(|r| r.state_update.monitoring_data),
last_tree_head: LastTreeHead(tree_head, tree_root),
};
Ok(MaybePartial {
inner: updated_account_data,
missing_fields,
})
}
/// This function tries to match the optional value in request and response.
///
/// The rules of matching are:
/// - If neither `request_value` nor `response_value` is present, the result is
/// considered complete (in `MaybePartial` terms) and will require no further
/// handling. It is expected to not have a value in the response if it had
/// never been requested to start with.
/// - If both `request_value` and `response_value` are present, the result is
/// considered complete and ready for further verification.
/// - If `response_value` is present but `request_value` is not, there is
/// something wrong with the server implementation. We never requested the
/// field, but the response contains a corresponding value.
/// - If `request_value` is present but `response_value` isn't we consider the
/// response complete but not suitable for further processing and record a
/// missing field inside `MaybePartial`.
fn match_optional_fields<T, U>(
request_value: Option<T>,
response_value: Option<U>,
field: AccountDataField,
) -> Result<MaybePartial<Option<(T, U)>>, Error> {
match (request_value, response_value) {
(Some(a), Some(b)) => Ok(MaybePartial::new_complete(Some((a, b)))),
(None, None) => Ok(MaybePartial::new_complete(None)),
(None, Some(_)) => Err(Error::InvalidResponse(format!(
"Unexpected field in the response: {}",
&field
))),
(Some(_), None) => Ok(MaybePartial::new(None, vec![field])),
}
}
#[cfg(test)]
pub(crate) mod test_support {
use std::cell::Cell;
use std::time::Duration;
use assert_matches::assert_matches;
use const_str::hex;
use libsignal_keytrans::{StoredAccountData, TreeHead};
use libsignal_net::env;
@@ -997,13 +640,118 @@ pub(crate) mod test_support {
pub fn test_account_data() -> AccountData {
AccountData::try_from(test_stored_account_data()).expect("valid account data")
}
#[derive(Debug, Clone)]
pub struct OwnedParameters {
pub aci: Aci,
pub e164: Option<(E164, Vec<u8>)>,
pub username_hash_bytes: Option<Vec<u8>>,
}
#[derive(Default)]
pub struct SearchStub {
pub result: Option<Result<MaybePartial<AccountData>, RequestError<Error>>>,
pub invocations: Vec<OwnedParameters>,
}
impl SearchStub {
fn new(res: Option<Result<MaybePartial<AccountData>, RequestError<Error>>>) -> Self {
Self {
result: res,
invocations: vec![],
}
}
}
pub struct TestKt {
pub monitor: Cell<Option<Result<AccountData, RequestError<Error>>>>,
pub search: Cell<SearchStub>,
}
impl TestKt {
pub fn for_monitor(monitor: Result<AccountData, RequestError<Error>>) -> Self {
Self::new(Some(monitor), None)
}
pub fn for_search(search: Result<MaybePartial<AccountData>, RequestError<Error>>) -> Self {
Self::new(None, Some(search))
}
pub fn new(
monitor: Option<Result<AccountData, RequestError<Error>>>,
search: Option<Result<MaybePartial<AccountData>, RequestError<Error>>>,
) -> Self {
Self {
monitor: Cell::new(monitor),
search: Cell::new(SearchStub::new(search)),
}
}
pub fn expected_error() -> RequestError<Error> {
RequestError::Unexpected {
log_safe: "test error".to_string(),
}
}
pub fn assert_expected_error<T: Debug>(result: Result<T, RequestError<Error>>) {
assert_matches!(result, Err(RequestError::Unexpected { log_safe }) if log_safe == "test error")
}
}
impl UnauthenticatedChatApi for TestKt {
fn search(
&self,
aci: &Aci,
_aci_identity_key: &PublicKey,
e164: Option<(E164, Vec<u8>)>,
username_hash: Option<UsernameHash<'_>>,
_stored_account_data: Option<AccountData>,
_distinguished_tree_head: &LastTreeHead,
) -> impl Future<Output = Result<MaybePartial<AccountData>, RequestError<Error>>> + Send
{
let mut search_stub = self.search.take();
search_stub.invocations.push(OwnedParameters {
aci: *aci,
e164,
username_hash_bytes: username_hash.map(|x| x.as_ref().to_vec()),
});
let result = search_stub
.result
.as_ref()
.expect("unexpected call to search")
.clone();
self.search.set(search_stub);
std::future::ready(result.clone())
}
fn distinguished(
&self,
_: Option<LastTreeHead>,
) -> impl Future<Output = Result<SearchStateUpdate, RequestError<Error>>> {
// not used in the tests
unreachable!();
#[allow(unreachable_code)] // without this, `impl Future` gets confused
std::future::pending()
}
fn monitor(
&self,
_aci: &Aci,
_e164: Option<E164>,
_username_hash: Option<UsernameHash<'_>>,
_account_data: AccountData,
_last_distinguished_tree_head: &LastTreeHead,
) -> impl Future<Output = Result<AccountData, RequestError<Error>>> + Send {
let result = self.monitor.take().expect("unexpected call to monitor");
std::future::ready(result)
}
}
}
#[cfg(test)]
mod test {
use std::cell::Cell;
use std::sync::{Arc, Mutex};
use assert_matches::assert_matches;
use prost::Message as _;
use test_case::test_case;
@@ -1018,7 +766,7 @@ mod test {
let chat_search_response =
libsignal_keytrans::ChatSearchResponse::decode(CHAT_SEARCH_RESPONSE)
.expect("valid response");
TypedSearchResponse::from_untyped(true, true, chat_search_response)
TypedSearchResponse::from_untyped(chat_search_response)
.expect("valid typed search response")
}
@@ -1049,8 +797,7 @@ mod test {
config: KEYTRANS_CONFIG_STAGING.into(),
};
let result = verify_chat_search_response(
&kt,
let result = kt.verify_chat_search_response(
&aci,
e164,
username_hash,
@@ -1091,8 +838,7 @@ mod test {
config: KEYTRANS_CONFIG_STAGING.into(),
};
let result = verify_chat_search_response(
&kt,
let result = kt.verify_chat_search_response(
&aci,
Some(e164),
Some(username_hash),
@@ -1106,244 +852,4 @@ mod test {
assert_eq!(skip.to_vec(), missing_fields.into_iter().collect::<Vec<_>>())
);
}
#[derive(Default)]
struct SearchStub {
result: Option<Result<MaybePartial<AccountData>, RequestError<Error>>>,
}
impl SearchStub {
fn new(res: Result<MaybePartial<AccountData>, RequestError<Error>>) -> Self {
Self { result: Some(res) }
}
}
struct TestKt {
monitor: Cell<Option<Result<AccountData, RequestError<Error>>>>,
search: Arc<Mutex<SearchStub>>,
}
impl TestKt {
fn for_monitor(monitor: Result<AccountData, RequestError<Error>>) -> Self {
Self {
monitor: Cell::new(Some(monitor)),
search: Arc::new(Mutex::new(Default::default())),
}
}
fn new(
monitor: Result<AccountData, RequestError<Error>>,
search: Result<MaybePartial<AccountData>, RequestError<Error>>,
) -> Self {
Self {
monitor: Cell::new(Some(monitor)),
search: Arc::new(Mutex::new(SearchStub::new(search))),
}
}
}
impl UnauthenticatedChatApi for TestKt {
fn search(
&self,
_aci: &Aci,
_aci_identity_key: &PublicKey,
_e164: Option<(E164, Vec<u8>)>,
_username_hash: Option<UsernameHash<'_>>,
_stored_account_data: Option<AccountData>,
_distinguished_tree_head: &LastTreeHead,
) -> impl Future<Output = Result<MaybePartial<AccountData>, RequestError<Error>>> + Send
{
let guard = self.search.lock().expect("can lock");
if let Some(result) = guard.result.as_ref() {
std::future::ready(result.clone())
} else {
panic!("unexpected call to search")
}
}
fn distinguished(
&self,
_: Option<LastTreeHead>,
) -> impl Future<Output = Result<SearchStateUpdate, RequestError<Error>>> {
// not used in the tests
unreachable!();
#[allow(unreachable_code)] // without this, `impl Future` gets confused
std::future::pending()
}
fn monitor(
&self,
_aci: &Aci,
_e164: Option<E164>,
_username_hash: Option<UsernameHash<'_>>,
_account_data: AccountData,
_last_distinguished_tree_head: &LastTreeHead,
) -> impl Future<Output = Result<AccountData, RequestError<Error>>> + Send {
let result = self.monitor.take().expect("unexpected call to monitor");
std::future::ready(result)
}
}
#[tokio::test]
async fn monitor_and_search_monitor_error_is_returned() {
let kt = TestKt::for_monitor(Err(RequestError::Unexpected {
log_safe: "pass through unexpected error".to_owned(),
}));
let result = monitor_and_search(
&kt,
&test_account::aci(),
&test_account::aci_identity_key(),
None,
None,
test_account_data(),
&test_distinguished_tree(),
MonitorMode::MonitorOther,
)
.await;
assert_matches!(
result,
Err(RequestError::Unexpected { log_safe: msg }) if msg == "pass through unexpected error"
);
}
#[tokio::test]
async fn monitor_and_search_no_search_needed() {
let monitor_result = test_account_data();
// TestKt constructed like this will panic if search is invoked
let kt = TestKt::for_monitor(Ok(monitor_result.clone()));
let actual = monitor_and_search(
&kt,
&test_account::aci(),
&test_account::aci_identity_key(),
None,
None,
test_account_data(),
&test_distinguished_tree(),
MonitorMode::MonitorOther,
)
.await
.expect("monitor should succeed");
assert_eq!(actual, monitor_result.into());
}
#[derive(Clone, Copy)]
enum BumpVersionFor {
Aci,
E164,
UsernameHash,
}
trait Bumpable {
fn apply(&mut self, bump: BumpVersionFor);
}
impl Bumpable for AccountData {
fn apply(&mut self, bump: BumpVersionFor) {
let subject = match bump {
BumpVersionFor::Aci => &mut self.aci,
BumpVersionFor::E164 => self.e164.as_mut().unwrap(),
BumpVersionFor::UsernameHash => self.username_hash.as_mut().unwrap(),
};
// inserting a newer version of the subject
let max_version = subject.greatest_version();
subject.ptrs.insert(u64::MAX, max_version + 1);
}
}
impl Bumpable for SearchVersions {
fn apply(&mut self, bump: BumpVersionFor) {
let subject = match bump {
BumpVersionFor::Aci => &mut self.aci,
BumpVersionFor::E164 => &mut self.e164,
BumpVersionFor::UsernameHash => &mut self.username_hash,
};
if let Some(v) = subject.as_mut() {
*v += 1;
} else {
*subject = Some(1);
}
}
}
#[tokio::test]
#[test_case(BumpVersionFor::Aci; "newer Aci")]
#[test_case(BumpVersionFor::E164; "newer E.164")]
#[test_case(BumpVersionFor::UsernameHash; "newer username hash")]
async fn monitor_and_search_e164_changed(bump: BumpVersionFor) {
let mut monitor_result = test_account_data();
monitor_result.apply(bump);
let kt = TestKt::new(
Ok(monitor_result.clone()),
Err(RequestError::Unexpected {
log_safe: "pass through unexpected error".to_owned(),
}),
);
let result = monitor_and_search(
&kt,
&test_account::aci(),
&test_account::aci_identity_key(),
None,
None,
test_account_data(),
&test_distinguished_tree(),
MonitorMode::MonitorOther,
)
.await;
// monitor invocation should have succeeded, and search
// should have been invoked returning our custom error
assert_matches!(
result,
Err(RequestError::Unexpected { log_safe: msg }) if msg == "pass through unexpected error"
);
}
#[tokio::test]
async fn monitor_and_search_search_success() {
let mut monitor_result = test_account_data();
// inserting a newer version of the username hash
let max_version = monitor_result
.username_hash
.as_ref()
.unwrap()
.greatest_version();
monitor_result
.username_hash
.as_mut()
.unwrap()
.ptrs
.insert(u64::MAX, max_version + 1);
let mut search_result_account_data = test_account_data();
// make some unique change to validate this is the one that gets returned
search_result_account_data.last_tree_head.1 = [42; 32];
let kt = TestKt::new(
Ok(monitor_result.clone()),
Ok(search_result_account_data.clone().into()),
);
let updated_account_data = monitor_and_search(
&kt,
&test_account::aci(),
&test_account::aci_identity_key(),
None,
None,
test_account_data(),
&test_distinguished_tree(),
MonitorMode::MonitorOther,
)
.await
.expect("both monitor and search should have succeeded");
assert_eq!(
search_result_account_data,
updated_account_data.into_inner()
);
}
}

View File

@@ -0,0 +1,107 @@
//
// Copyright 2026 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use std::collections::BTreeSet;
/// A tag identifying an optional field in [`libsignal_keytrans::AccountData`]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, displaydoc::Display)]
pub enum AccountDataField {
/// E.164
E164,
/// Username hash
UsernameHash,
}
/// This struct adds to its type parameter a (potentially empty) list of
/// account fields (see [`AccountDataField`]) that can no longer be verified
/// by the server.
///
/// Basically it is a non-generic version of a `Writer` monad with [`BTreeSet`] used
/// to accumulate missing field entries in some order while avoiding duplicates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaybePartial<T> {
pub inner: T,
pub missing_fields: BTreeSet<AccountDataField>,
}
impl<T> From<T> for MaybePartial<T> {
fn from(value: T) -> Self {
Self {
inner: value,
missing_fields: Default::default(),
}
}
}
impl<T> MaybePartial<T> {
pub(super) fn new_complete(inner: T) -> Self {
Self {
inner,
missing_fields: Default::default(),
}
}
pub(super) fn new(
inner: T,
missing_fields: impl IntoIterator<Item = AccountDataField>,
) -> Self {
Self {
inner,
missing_fields: BTreeSet::from_iter(missing_fields),
}
}
pub fn map<U>(self, f: impl FnOnce(T) -> U) -> MaybePartial<U> {
MaybePartial {
inner: f(self.inner),
missing_fields: self.missing_fields,
}
}
pub fn and_then<U>(self, f: impl FnOnce(T) -> MaybePartial<U>) -> MaybePartial<U> {
let MaybePartial {
inner,
mut missing_fields,
} = self;
let MaybePartial {
inner: final_inner,
missing_fields: other_missing,
} = f(inner);
missing_fields.extend(other_missing);
MaybePartial {
inner: final_inner,
missing_fields,
}
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn into_result(self) -> Result<T, BTreeSet<AccountDataField>> {
let Self {
inner,
missing_fields,
} = self;
if missing_fields.is_empty() {
Ok(inner)
} else {
Err(missing_fields)
}
}
}
impl<T, E> MaybePartial<Result<T, E>> {
pub fn transpose(self) -> Result<MaybePartial<T>, E> {
let MaybePartial {
inner,
missing_fields,
} = self;
Ok(MaybePartial {
inner: inner?,
missing_fields,
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,215 @@
//
// Copyright 2026 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use std::time::SystemTime;
use libsignal_core::{Aci, E164};
use libsignal_keytrans::{
AccountData, CondensedTreeSearchResponse, FullSearchResponse, FullTreeHead, KeyTransparency,
LastTreeHead, LocalStateUpdate, MonitoringData, SearchContext, SlimSearchRequest,
VerifiedSearchResult,
};
use super::{AccountDataField, Error, MaybePartial, SearchKey, TypedSearchResponse, UsernameHash};
pub(super) trait KeyTransparencyVerifyExt {
fn verify_single_search_response(
&self,
search_key: Vec<u8>,
response: CondensedTreeSearchResponse,
monitoring_data: Option<MonitoringData>,
full_tree_head: &FullTreeHead,
last_tree_head: Option<&LastTreeHead>,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<VerifiedSearchResult, Error>;
fn verify_chat_search_response(
&self,
aci: &Aci,
e164: Option<E164>,
username_hash: Option<UsernameHash>,
stored_account_data: Option<AccountData>,
chat_search_response: TypedSearchResponse,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<MaybePartial<AccountData>, Error>;
}
impl KeyTransparencyVerifyExt for KeyTransparency {
fn verify_single_search_response(
&self,
search_key: Vec<u8>,
response: CondensedTreeSearchResponse,
monitoring_data: Option<MonitoringData>,
full_tree_head: &FullTreeHead,
last_tree_head: Option<&LastTreeHead>,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<VerifiedSearchResult, Error> {
let result = self.verify_search(
SlimSearchRequest {
search_key,
version: None,
},
FullSearchResponse::new(response, full_tree_head),
SearchContext {
last_tree_head,
last_distinguished_tree_head,
data: monitoring_data,
},
true,
now,
)?;
Ok(result)
}
fn verify_chat_search_response(
&self,
aci: &Aci,
e164: Option<E164>,
username_hash: Option<UsernameHash>,
stored_account_data: Option<AccountData>,
chat_search_response: TypedSearchResponse,
last_distinguished_tree_head: Option<&LastTreeHead>,
now: SystemTime,
) -> Result<MaybePartial<AccountData>, Error> {
let TypedSearchResponse {
full_tree_head,
aci_search_response,
e164_search_response,
username_hash_search_response,
} = chat_search_response;
let (
aci_monitoring_data,
e164_monitoring_data,
username_hash_monitoring_data,
stored_last_tree_head,
) = match stored_account_data {
None => (None, None, None, None),
Some(acc) => {
let AccountData {
aci,
e164,
username_hash,
last_tree_head,
} = acc;
(Some(aci), e164, username_hash, Some(last_tree_head))
}
};
let aci_result = self.verify_single_search_response(
aci.as_search_key(),
aci_search_response,
aci_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)?;
let e164_result =
match_optional_fields(e164, e164_search_response, AccountDataField::E164)?
.map(|non_partial| {
non_partial
.map(|(e164, e164_search_response)| {
self.verify_single_search_response(
e164.as_search_key(),
e164_search_response,
e164_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)
})
.transpose()
})
.transpose()?;
let username_hash_result = match_optional_fields(
username_hash,
username_hash_search_response,
AccountDataField::UsernameHash,
)?
.map(|non_partial| {
non_partial
.map(|(username_hash, username_hash_response)| {
self.verify_single_search_response(
username_hash.as_search_key(),
username_hash_response,
username_hash_monitoring_data,
&full_tree_head,
stored_last_tree_head.as_ref(),
last_distinguished_tree_head,
now,
)
})
.transpose()
})
.transpose()?;
let MaybePartial {
inner: (e164_result, username_hash_result),
missing_fields,
} = e164_result.and_then(|e164| username_hash_result.map(|hash| (e164, hash)));
if !aci_result.are_all_roots_equal([e164_result.as_ref(), username_hash_result.as_ref()]) {
return Err(Error::InvalidResponse("mismatching tree roots".to_string()));
}
// ACI response is guaranteed to be present, taking the last tree head from it.
let LocalStateUpdate {
tree_head,
tree_root,
monitoring_data: updated_aci_monitoring_data,
} = aci_result.state_update;
let updated_account_data = AccountData {
aci: updated_aci_monitoring_data
.ok_or_else(|| Error::InvalidResponse("ACI data is missing".to_string()))?,
e164: e164_result.and_then(|r| r.state_update.monitoring_data),
username_hash: username_hash_result.and_then(|r| r.state_update.monitoring_data),
last_tree_head: LastTreeHead(tree_head, tree_root),
};
Ok(MaybePartial {
inner: updated_account_data,
missing_fields,
})
}
}
/// This function tries to match the optional value in request and response.
///
/// The rules of matching are:
/// - If neither `request_value` nor `response_value` is present, the result is
/// considered complete (in `MaybePartial` terms) and will require no further
/// handling. It is expected to not have a value in the response if it had
/// never been requested to start with.
/// - If both `request_value` and `response_value` are present, the result is
/// considered complete and ready for further verification.
/// - If `response_value` is present but `request_value` is not, there is
/// something wrong with the server implementation. We never requested the
/// field, but the response contains a corresponding value.
/// - If `request_value` is present but `response_value` isn't we consider the
/// response complete but not suitable for further processing and record a
/// missing field inside `MaybePartial`.
fn match_optional_fields<T, U>(
request_value: Option<T>,
response_value: Option<U>,
field: AccountDataField,
) -> Result<MaybePartial<Option<(T, U)>>, Error> {
match (request_value, response_value) {
(Some(a), Some(b)) => Ok(MaybePartial::new_complete(Some((a, b)))),
(None, None) => Ok(MaybePartial::new_complete(None)),
(None, Some(_)) => Err(Error::InvalidResponse(format!(
"Unexpected field in the response: {}",
&field
))),
(Some(_), None) => Ok(MaybePartial::new(None, vec![field])),
}
}

View File

@@ -496,7 +496,7 @@ mod test_support {
{
let search_response = ChatSearchResponse::decode(response_bytes.as_ref())
.map_err(|_| Error::InvalidResponse("bad protobuf".to_string()))
.and_then(|r| TypedSearchResponse::from_untyped(true, true, r))
.and_then(TypedSearchResponse::from_untyped)
.expect("valid search response");
let tree_size = search_response.full_tree_head.tree_head.unwrap().tree_size;