Add PQXDH tests

This commit is contained in:
moiseev-signal
2023-05-23 16:14:44 -07:00
committed by GitHub
parent b0a1bf2bd6
commit 28e112bac1
11 changed files with 2215 additions and 1702 deletions

View File

@@ -3,27 +3,30 @@
// SPDX-License-Identifier: AGPL-3.0-only
//
// Different parts of this module are used in different tests/benchmarks, therefore some of the
// APIs will always be considered dead code.
#![allow(dead_code)]
use futures_util::FutureExt;
use libsignal_protocol::*;
use rand::rngs::OsRng;
use rand::{CryptoRng, Rng};
use std::ops::RangeFrom;
// Deliberately not reusing the constants from `protocol`.
// dead_code warning reported when building benches.
#[allow(dead_code)]
pub(crate) const PRE_KYBER_MESSAGE_VERSION: u32 = 3;
#[allow(dead_code)]
pub(crate) const KYBER_AWARE_MESSAGE_VERSION: u32 = 4;
#[allow(dead_code)]
pub fn test_in_memory_protocol_store() -> Result<InMemSignalProtocolStore, SignalProtocolError> {
let mut csprng = OsRng;
let identity_key = IdentityKeyPair::generate(&mut csprng);
let registration_id = 5; // fixme randomly generate this
// Valid registration IDs fit in 14 bits.
let registration_id: u8 = csprng.gen();
InMemSignalProtocolStore::new(identity_key, registration_id)
InMemSignalProtocolStore::new(identity_key, registration_id as u32)
}
#[allow(dead_code)]
pub async fn encrypt(
store: &mut InMemSignalProtocolStore,
remote_address: &ProtocolAddress,
@@ -39,7 +42,6 @@ pub async fn encrypt(
.await
}
#[allow(dead_code)]
pub async fn decrypt(
store: &mut InMemSignalProtocolStore,
remote_address: &ProtocolAddress,
@@ -60,7 +62,6 @@ pub async fn decrypt(
.await
}
#[allow(dead_code)]
pub async fn create_pre_key_bundle<R: Rng + CryptoRng>(
store: &mut dyn ProtocolStore,
mut csprng: &mut R,
@@ -115,7 +116,6 @@ pub async fn create_pre_key_bundle<R: Rng + CryptoRng>(
Ok(pre_key_bundle)
}
#[allow(dead_code)]
pub fn initialize_sessions_v3() -> Result<(SessionRecord, SessionRecord), SignalProtocolError> {
let mut csprng = OsRng;
let alice_identity = IdentityKeyPair::generate(&mut csprng);
@@ -141,7 +141,7 @@ pub fn initialize_sessions_v3() -> Result<(SessionRecord, SessionRecord), Signal
bob_base_key,
None,
bob_ephemeral_key,
None, // TODO: add kyber prekeys
None,
*alice_identity.identity_key(),
alice_base_key.public_key,
None,
@@ -151,3 +151,271 @@ pub fn initialize_sessions_v3() -> Result<(SessionRecord, SessionRecord), Signal
Ok((alice_session, bob_session))
}
pub fn initialize_sessions_v4() -> Result<(SessionRecord, SessionRecord), SignalProtocolError> {
let mut csprng = OsRng;
let alice_identity = IdentityKeyPair::generate(&mut csprng);
let bob_identity = IdentityKeyPair::generate(&mut csprng);
let alice_base_key = KeyPair::generate(&mut csprng);
let bob_base_key = KeyPair::generate(&mut csprng);
let bob_ephemeral_key = bob_base_key;
let bob_kyber_key = kem::KeyPair::generate(kem::KeyType::Kyber1024);
let alice_params = AliceSignalProtocolParameters::new(
alice_identity,
alice_base_key,
*bob_identity.identity_key(),
bob_base_key.public_key,
bob_ephemeral_key.public_key,
)
.with_their_kyber_pre_key(&bob_kyber_key.public_key);
let alice_session = initialize_alice_session_record(&alice_params, &mut csprng)?;
let kyber_ciphertext = {
let bytes = alice_session
.get_kyber_ciphertext()?
.expect("has kyber ciphertext")
.clone();
bytes.into_boxed_slice()
};
let bob_params = BobSignalProtocolParameters::new(
bob_identity,
bob_base_key,
None,
bob_ephemeral_key,
Some(bob_kyber_key),
*alice_identity.identity_key(),
alice_base_key.public_key,
Some(&kyber_ciphertext),
);
let bob_session = initialize_bob_session_record(&bob_params)?;
Ok((alice_session, bob_session))
}
pub enum IdChoice {
Exactly(u32),
Next,
Random,
}
impl From<u32> for IdChoice {
fn from(id: u32) -> Self {
IdChoice::Exactly(id)
}
}
pub struct TestStoreBuilder {
rng: OsRng,
pub(crate) store: InMemSignalProtocolStore,
id_range: RangeFrom<u32>,
}
impl TestStoreBuilder {
pub fn new() -> Self {
let mut rng = OsRng;
let identity_key = IdentityKeyPair::generate(&mut rng);
// Valid registration IDs fit in 14 bits.
let registration_id: u8 = rng.gen();
let store = InMemSignalProtocolStore::new(identity_key, registration_id as u32)
.expect("can create store");
Self {
rng,
store,
id_range: 0..,
}
}
pub fn from_store(store: &InMemSignalProtocolStore) -> Self {
Self {
rng: OsRng,
store: store.clone(),
id_range: 0..,
}
}
pub fn with_pre_key(mut self, id_choice: IdChoice) -> Self {
self.add_pre_key(id_choice);
self
}
pub fn add_pre_key(&mut self, id_choice: IdChoice) {
let id = self.gen_id(id_choice);
// TODO: this requirement can be removed if store returns ids in the insertion order
if let Some(latest_id) = self.store.all_pre_key_ids().last() {
assert!(id > (*latest_id).into(), "Pre key ids should be increasing");
}
let pair = KeyPair::generate(&mut self.rng);
self.store
.save_pre_key(id.into(), &PreKeyRecord::new(id.into(), &pair), None)
.now_or_never()
.expect("sync")
.expect("able to store pre key");
}
pub fn with_signed_pre_key(mut self, id_choice: IdChoice) -> Self {
self.add_signed_pre_key(id_choice);
self
}
pub fn add_signed_pre_key(&mut self, id_choice: IdChoice) {
let id = self.gen_id(id_choice);
if let Some(latest_id) = self.store.all_signed_pre_key_ids().last() {
assert!(
id > (*latest_id).into(),
"Signed pre key ids should be increasing"
);
}
let pair = KeyPair::generate(&mut self.rng);
let public = pair.public_key.serialize();
let signature = self.sign(&public);
let record = SignedPreKeyRecord::new(id.into(), 42, &pair, &signature);
self.store
.save_signed_pre_key(id.into(), &record, None)
.now_or_never()
.expect("sync")
.expect("able to store signed pre key");
}
pub fn with_kyber_pre_key(mut self, id_choice: IdChoice) -> Self {
self.add_kyber_pre_key(id_choice);
self
}
pub fn add_kyber_pre_key(&mut self, id_choice: IdChoice) {
let id = self.gen_id(id_choice);
if let Some(latest_id) = self.store.all_kyber_pre_key_ids().last() {
assert!(
id > (*latest_id).into(),
"Signed pre key ids should be increasing"
);
}
let pair = kem::KeyPair::generate(kem::KeyType::Kyber1024);
let public = pair.public_key.serialize();
let signature = self.sign(&public);
let record = KyberPreKeyRecord::new(id.into(), 43, &pair, &signature);
self.store
.save_kyber_pre_key(id.into(), &record, None)
.now_or_never()
.expect("sync")
.expect("able toe store kyber pre key");
}
pub fn make_bundle_with_latest_keys(&self, device_id: DeviceId) -> PreKeyBundle {
let registration_id = self
.store
.get_local_registration_id(None)
.now_or_never()
.expect("sync")
.expect("contains local registration id");
let maybe_pre_key_record = self.store.all_pre_key_ids().max().map(|id| {
self.store
.pre_key_store
.get_pre_key(*id, None)
.now_or_never()
.expect("syng")
.expect("has pre key")
});
let identity_key_pair = self
.store
.get_identity_key_pair(None)
.now_or_never()
.expect("sync")
.expect("has identity key pair");
let identity_key = identity_key_pair.identity_key();
let signed_pre_key_record = self
.store
.all_signed_pre_key_ids()
.max()
.map(|id| {
self.store
.get_signed_pre_key(*id, None)
.now_or_never()
.expect("sync")
.expect("has signed pre key")
})
.expect("contains at least one signed pre key");
let maybe_kyber_pre_key_record = self.store.all_kyber_pre_key_ids().max().map(|id| {
self.store
.get_kyber_pre_key(*id, None)
.now_or_never()
.expect("sync")
.expect("has kyber pre key")
});
let mut bundle = PreKeyBundle::new(
registration_id,
device_id,
maybe_pre_key_record.map(|rec| {
(
rec.id().expect("has id"),
rec.public_key().expect("has public key"),
)
}),
signed_pre_key_record.id().expect("has id"),
signed_pre_key_record.public_key().expect("has public key"),
signed_pre_key_record.signature().expect("has signature"),
*identity_key,
)
.expect("can make pre key bundle from store");
if let Some(rec) = maybe_kyber_pre_key_record {
bundle = bundle.with_kyber_pre_key(
rec.id().expect("has id"),
rec.public_key().expect("has public key"),
rec.signature().expect("has signature"),
);
}
bundle
}
fn sign(&mut self, message: &[u8]) -> Box<[u8]> {
let identity_key_pair = self
.store
.get_identity_key_pair(None)
.now_or_never()
.expect("sync")
.expect("able to get identity");
let signing_key = identity_key_pair.private_key();
signing_key
.calculate_signature(message, &mut self.rng)
.expect("able to sign with identity key")
}
fn next_id(&mut self) -> u32 {
self.id_range.next().expect("should have enough ids")
}
fn gen_id(&mut self, choice: IdChoice) -> u32 {
match choice {
IdChoice::Exactly(id) => id,
// TODO: check the maximal existing id and continue from it
IdChoice::Next => self.next_id(),
IdChoice::Random => self.rng.gen(),
}
}
}
pub trait HasSessionVersion {
fn session_version(&self, address: &ProtocolAddress) -> Result<u32, SignalProtocolError>;
}
impl HasSessionVersion for TestStoreBuilder {
fn session_version(&self, address: &ProtocolAddress) -> Result<u32, SignalProtocolError> {
self.store.session_version(address)
}
}
impl HasSessionVersion for InMemSignalProtocolStore {
fn session_version(&self, address: &ProtocolAddress) -> Result<u32, SignalProtocolError> {
self.load_session(address, None)
.now_or_never()
.expect("sync")?
.expect("session found")
.session_version()
}
}