// // Copyright 2020 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // // Different parts of this module are used in different tests/benchmarks, therefore some of the // APIs will always be considered dead code. #![allow(dead_code)] use std::ops::RangeFrom; use std::time::SystemTime; use futures_util::FutureExt; use libsignal_protocol::*; use rand::rngs::OsRng; use rand::{CryptoRng, Rng, TryRngCore as _}; // Deliberately not reusing the constants from `protocol`. pub(crate) const PRE_KYBER_MESSAGE_VERSION: u32 = 3; pub(crate) const KYBER_AWARE_MESSAGE_VERSION: u32 = 4; pub fn test_in_memory_protocol_store() -> Result { let mut csprng = OsRng.unwrap_err(); let identity_key = IdentityKeyPair::generate(&mut csprng); // Valid registration IDs fit in 14 bits. let registration_id: u8 = csprng.random(); InMemSignalProtocolStore::new(identity_key, registration_id as u32) } pub async fn encrypt( store: &mut InMemSignalProtocolStore, remote_address: &ProtocolAddress, local_address: &ProtocolAddress, msg: &str, ) -> Result { let mut csprng = OsRng.unwrap_err(); message_encrypt( msg.as_bytes(), remote_address, local_address, &mut store.session_store, &mut store.identity_store, SystemTime::now(), &mut csprng, ) .await } pub async fn decrypt( store: &mut InMemSignalProtocolStore, remote_address: &ProtocolAddress, local_address: &ProtocolAddress, msg: &CiphertextMessage, ) -> Result, SignalProtocolError> { let mut csprng = OsRng.unwrap_err(); message_decrypt( msg, remote_address, local_address, &mut store.session_store, &mut store.identity_store, &mut store.pre_key_store, &store.signed_pre_key_store, &mut store.kyber_pre_key_store, &mut csprng, ) .await } pub async fn create_pre_key_bundle( store: &mut dyn ProtocolStore, mut csprng: &mut R, ) -> Result { let pre_key_pair = KeyPair::generate(&mut csprng); let signed_pre_key_pair = KeyPair::generate(&mut csprng); let kyber_pre_key_pair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut csprng); let signed_pre_key_public = signed_pre_key_pair.public_key.serialize(); let signed_pre_key_signature = store .get_identity_key_pair() .await? .private_key() .calculate_signature(&signed_pre_key_public, &mut csprng)?; let kyber_pre_key_public = kyber_pre_key_pair.public_key.serialize(); let kyber_pre_key_signature = store .get_identity_key_pair() .await? .private_key() .calculate_signature(&kyber_pre_key_public, &mut csprng)?; let device_id: DeviceId = csprng.random(); let pre_key_id: u32 = csprng.random(); let signed_pre_key_id: u32 = csprng.random(); let kyber_pre_key_id: u32 = csprng.random(); let pre_key_bundle = PreKeyBundle::new( store.get_local_registration_id().await?, device_id, Some((pre_key_id.into(), pre_key_pair.public_key)), signed_pre_key_id.into(), signed_pre_key_pair.public_key, signed_pre_key_signature.to_vec(), kyber_pre_key_id.into(), kyber_pre_key_pair.public_key.clone(), kyber_pre_key_signature.to_vec(), *store.get_identity_key_pair().await?.identity_key(), )?; store .save_pre_key( pre_key_id.into(), &PreKeyRecord::new(pre_key_id.into(), &pre_key_pair), ) .await?; let timestamp = Timestamp::from_epoch_millis(csprng.random()); store .save_signed_pre_key( signed_pre_key_id.into(), &SignedPreKeyRecord::new( signed_pre_key_id.into(), timestamp, &signed_pre_key_pair, &signed_pre_key_signature, ), ) .await?; store .save_kyber_pre_key( kyber_pre_key_id.into(), &KyberPreKeyRecord::new( kyber_pre_key_id.into(), Timestamp::from_epoch_millis(43), &kyber_pre_key_pair, &kyber_pre_key_signature, ), ) .await?; Ok(pre_key_bundle) } pub fn initialize_sessions_v4() -> Result<(SessionRecord, SessionRecord), SignalProtocolError> { let mut csprng = OsRng.unwrap_err(); let alice_identity = IdentityKeyPair::generate(&mut csprng); let bob_identity = IdentityKeyPair::generate(&mut csprng); let alice_base_key = KeyPair::generate(&mut csprng); let bob_base_key = KeyPair::generate(&mut csprng); let bob_ephemeral_key = bob_base_key; let bob_kyber_key = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut csprng); let alice_params = AliceSignalProtocolParameters::new( alice_identity, alice_base_key, *bob_identity.identity_key(), bob_base_key.public_key, bob_ephemeral_key.public_key, bob_kyber_key.public_key.clone(), ); let alice_session = initialize_alice_session_record(&alice_params, &mut csprng)?; let kyber_ciphertext = { let bytes = alice_session .get_kyber_ciphertext()? .expect("has kyber ciphertext") .clone(); bytes.into_boxed_slice() }; let bob_params = BobSignalProtocolParameters::new( bob_identity, bob_base_key, None, bob_ephemeral_key, bob_kyber_key, *alice_identity.identity_key(), alice_base_key.public_key, &kyber_ciphertext, ); let bob_session = initialize_bob_session_record(&bob_params)?; Ok((alice_session, bob_session)) } pub fn extract_single_ssv2_received_message(input: &[u8]) -> (ServiceId, Vec) { let message = SealedSenderV2SentMessage::parse(input).expect("valid"); assert_eq!(1, message.recipients.len()); let (service_id, recipient) = message.recipients.first().expect("checked length"); let result = message .received_message_parts_for_recipient(recipient) .as_ref() .concat(); (*service_id, result) } pub enum IdChoice { Exactly(u32), Next, Random, } impl From for IdChoice { fn from(id: u32) -> Self { IdChoice::Exactly(id) } } pub struct TestStoreBuilder { rng: rand_core::UnwrapErr, pub(crate) store: InMemSignalProtocolStore, id_range: RangeFrom, } impl TestStoreBuilder { pub fn new() -> Self { let mut rng = OsRng.unwrap_err(); let identity_key = IdentityKeyPair::generate(&mut rng); // Valid registration IDs fit in 14 bits. let registration_id: u8 = rng.random(); let store = InMemSignalProtocolStore::new(identity_key, registration_id as u32) .expect("can create store"); Self { rng, store, id_range: 0.., } } pub fn from_store(store: &InMemSignalProtocolStore) -> Self { Self { rng: OsRng.unwrap_err(), store: store.clone(), id_range: 0.., } } pub fn with_pre_key(mut self, id_choice: IdChoice) -> Self { self.add_pre_key(id_choice); self } pub fn add_pre_key(&mut self, id_choice: IdChoice) { let id = self.gen_id(id_choice); // TODO: this requirement can be removed if store returns ids in the insertion order if let Some(latest_id) = self.store.all_pre_key_ids().last() { assert!(id > (*latest_id).into(), "Pre key ids should be increasing"); } let pair = KeyPair::generate(&mut self.rng); self.store .save_pre_key(id.into(), &PreKeyRecord::new(id.into(), &pair)) .now_or_never() .expect("sync") .expect("able to store pre key"); } pub fn with_signed_pre_key(mut self, id_choice: IdChoice) -> Self { self.add_signed_pre_key(id_choice); self } pub fn add_signed_pre_key(&mut self, id_choice: IdChoice) { let id = self.gen_id(id_choice); if let Some(latest_id) = self.store.all_signed_pre_key_ids().last() { assert!( id > (*latest_id).into(), "Signed pre key ids should be increasing" ); } let pair = KeyPair::generate(&mut self.rng); let public = pair.public_key.serialize(); let signature = self.sign(&public); let record = SignedPreKeyRecord::new( id.into(), Timestamp::from_epoch_millis(42), &pair, &signature, ); self.store .save_signed_pre_key(id.into(), &record) .now_or_never() .expect("sync") .expect("able to store signed pre key"); } pub fn with_kyber_pre_key(mut self, id_choice: IdChoice) -> Self { self.add_kyber_pre_key(id_choice); self } pub fn add_kyber_pre_key(&mut self, id_choice: IdChoice) { let id = self.gen_id(id_choice); if let Some(latest_id) = self.store.all_kyber_pre_key_ids().last() { assert!( id > (*latest_id).into(), "Signed pre key ids should be increasing" ); } let pair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut self.rng); let public = pair.public_key.serialize(); let signature = self.sign(&public); let record = KyberPreKeyRecord::new( id.into(), Timestamp::from_epoch_millis(43), &pair, &signature, ); self.store .save_kyber_pre_key(id.into(), &record) .now_or_never() .expect("sync") .expect("able toe store kyber pre key"); } pub fn make_bundle_with_latest_keys(&self, device_id: DeviceId) -> PreKeyBundle { let registration_id = self .store .get_local_registration_id() .now_or_never() .expect("sync") .expect("contains local registration id"); let maybe_pre_key_record = self.store.all_pre_key_ids().max().map(|id| { self.store .pre_key_store .get_pre_key(*id) .now_or_never() .expect("syng") .expect("has pre key") }); let identity_key_pair = self .store .get_identity_key_pair() .now_or_never() .expect("sync") .expect("has identity key pair"); let identity_key = identity_key_pair.identity_key(); let signed_pre_key_record = self .store .all_signed_pre_key_ids() .max() .map(|id| { self.store .get_signed_pre_key(*id) .now_or_never() .expect("sync") .expect("has signed pre key") }) .expect("contains at least one signed pre key"); let kyber_pre_key_record = self .store .all_kyber_pre_key_ids() .max() .map(|id| { self.store .get_kyber_pre_key(*id) .now_or_never() .expect("sync") .expect("has kyber pre key") }) .expect("contains at least one kyber key"); PreKeyBundle::new( registration_id, device_id, maybe_pre_key_record.map(|rec| { ( rec.id().expect("has id"), rec.public_key().expect("has public key"), ) }), signed_pre_key_record.id().expect("has id"), signed_pre_key_record.public_key().expect("has public key"), signed_pre_key_record.signature().expect("has signature"), kyber_pre_key_record.id().expect("has id"), kyber_pre_key_record.public_key().expect("has public key"), kyber_pre_key_record.signature().expect("has signature"), *identity_key, ) .expect("can make pre key bundle from store") } fn sign(&mut self, message: &[u8]) -> Box<[u8]> { let identity_key_pair = self .store .get_identity_key_pair() .now_or_never() .expect("sync") .expect("able to get identity"); let signing_key = identity_key_pair.private_key(); signing_key .calculate_signature(message, &mut self.rng) .expect("able to sign with identity key") } fn next_id(&mut self) -> u32 { self.id_range.next().expect("should have enough ids") } fn gen_id(&mut self, choice: IdChoice) -> u32 { match choice { IdChoice::Exactly(id) => id, // TODO: check the maximal existing id and continue from it IdChoice::Next => self.next_id(), IdChoice::Random => self.rng.random(), } } } pub trait HasSessionVersion { fn session_version(&self, address: &ProtocolAddress) -> Result; } impl HasSessionVersion for TestStoreBuilder { fn session_version(&self, address: &ProtocolAddress) -> Result { self.store.session_version(address) } } impl HasSessionVersion for InMemSignalProtocolStore { fn session_version(&self, address: &ProtocolAddress) -> Result { self.load_session(address) .now_or_never() .expect("sync")? .expect("session found") .session_version() } }