diff --git a/rust/protocol/src/double_ratchet.rs b/rust/protocol/src/double_ratchet.rs new file mode 100644 index 000000000..15aa0c0f0 --- /dev/null +++ b/rust/protocol/src/double_ratchet.rs @@ -0,0 +1,381 @@ +// +// Copyright 2026 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +//! Double Ratchet implementation. +//! +//! The root key, sender chain, and counters are deserialized into typed +//! fields. Receiver chains are kept in protobuf form and deserialized +//! lazily on demand. +//! +//! References: +//! - [Double Ratchet spec](https://signal.org/docs/specifications/doubleratchet/) + +use rand::{CryptoRng, Rng}; + +use crate::proto::storage::{SessionStructure, session_structure}; +use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey}; +use crate::state::InvalidSessionError; +use crate::{ + CiphertextMessageType, KeyPair, PrivateKey, PublicKey, Result, SignalProtocolError, consts, +}; + +// ── State ──────────────────────────────────────────────────────────── + +/// The state of a Double Ratchet session. +/// +/// Contains the root key, sending and receiving chains, and cached +/// message keys for out-of-order messages. Receiver chains are stored +/// in protobuf form and deserialized lazily on demand. +#[derive(Clone)] +pub(crate) struct RatchetState { + pub root_key: RootKey, + pub sender_chain: Option, + pub receiver_chains: Vec, + pub previous_counter: u32, + /// Maximum number of messages we'll skip ahead in a single chain. + /// Set to `consts::MAX_FORWARD_JUMPS` for normal sessions, + /// `usize::MAX` for self-sessions (note-to-self), following the + /// same pattern as SPQR's `max_jump` chain parameter. + pub max_forward_jumps: usize, +} + +/// The sending side of a ratchet: our current ephemeral key pair and +/// the chain key for encrypting outgoing messages. +#[derive(Clone)] +pub(crate) struct SenderChain { + pub ratchet_key: KeyPair, + pub chain_key: ChainKey, +} + +// Receiver chains are stored as raw `session_structure::Chain` protobuf in +// `RatchetState::receiver_chains`. Only the ratchet key and chain key are +// deserialized on demand; skipped message keys stay in protobuf form and are +// only deserialized individually when a matching counter is found. + +// ── Serialization bridge ────────────────────────────────────────────── + +impl RatchetState { + /// Deserialize the ratchet-relevant fields from a `SessionStructure`. + /// + /// Only reads root key, counters, and sender chain from `session`. + /// Receiver chains are passed in separately (moved, not cloned) by + /// the caller. + /// + /// Identity keys, registration IDs, pending pre-key state, and SPQR + /// state are owned by higher layers and are not touched here. + /// + /// `self_session` must be computed by the caller (requires identity + /// key comparison, which is not ratchet-layer knowledge). + pub(crate) fn from_pb( + session: &SessionStructure, + self_session: bool, + receiver_chains: Vec, + ) -> std::result::Result { + let root_key_bytes: [u8; 32] = session + .root_key + .as_slice() + .try_into() + .map_err(|_| InvalidSessionError("invalid root key"))?; + + let sender_chain = session + .sender_chain + .as_ref() + .map(SenderChain::from_pb) + .transpose()?; + + Ok(Self { + root_key: RootKey::new(root_key_bytes), + sender_chain, + receiver_chains, + previous_counter: session.previous_counter, + max_forward_jumps: if self_session { + usize::MAX + } else { + consts::MAX_FORWARD_JUMPS + }, + }) + } + + /// Write the ratchet state back into a `SessionStructure`. + /// + /// Only updates the ratchet-relevant fields; all other fields in + /// `session` are left unchanged. + pub(crate) fn apply_to_pb(self, session: &mut SessionStructure) { + let Self { + root_key, + sender_chain, + receiver_chains, + previous_counter, + max_forward_jumps: _, // not serialized; derived from session context + } = self; + session.root_key = root_key.key().to_vec(); + session.previous_counter = previous_counter; + session.sender_chain = sender_chain.as_ref().map(SenderChain::to_pb); + session.receiver_chains = receiver_chains; + } +} + +impl SenderChain { + fn from_pb(chain: &session_structure::Chain) -> std::result::Result { + let public_key = PublicKey::deserialize(&chain.sender_ratchet_key) + .map_err(|_| InvalidSessionError("invalid sender ratchet public key"))?; + let private_key = PrivateKey::deserialize(&chain.sender_ratchet_key_private) + .map_err(|_| InvalidSessionError("invalid sender ratchet private key"))?; + let chain_key = ChainKey::from_pb( + chain + .chain_key + .as_ref() + .ok_or(InvalidSessionError("missing sender chain key"))?, + )?; + Ok(Self { + ratchet_key: KeyPair { + public_key, + private_key, + }, + chain_key, + }) + } + + fn to_pb(&self) -> session_structure::Chain { + session_structure::Chain { + sender_ratchet_key: self.ratchet_key.public_key.serialize().to_vec(), + sender_ratchet_key_private: self.ratchet_key.private_key.serialize().to_vec(), + chain_key: Some(self.chain_key.to_pb()), + message_keys: vec![], + } + } +} + +impl ChainKey { + fn from_pb( + pb: &session_structure::chain::ChainKey, + ) -> std::result::Result { + let key: [u8; 32] = pb + .key + .as_slice() + .try_into() + .map_err(|_| InvalidSessionError("invalid chain key"))?; + Ok(Self::new(key, pb.index)) + } + + fn to_pb(&self) -> session_structure::chain::ChainKey { + session_structure::chain::ChainKey { + index: self.index(), + key: self.key().to_vec(), + } + } +} + +// ── Operations ─────────────────────────────────────────────────────── + +impl RatchetState { + fn take_root_key(&mut self) -> RootKey { + std::mem::replace(&mut self.root_key, RootKey::new([0; 32])) + } + + /// Ensure a receiver chain exists for a remote ephemeral key, returning + /// its chain key. + /// + /// If we already have a receiver chain for `their_ephemeral`, return + /// its chain key. Otherwise, perform a DH ratchet step: derive a new + /// receiver chain from the current root key and the remote ephemeral, + /// then generate a fresh sender chain. + pub fn ensure_receiver_chain( + &mut self, + their_ephemeral: &PublicKey, + csprng: &mut R, + ) -> Result { + if let Some(chain_key) = self.find_receiver_chain_key(their_ephemeral)? { + return Ok(chain_key); + } + + self.dh_ratchet_step(their_ephemeral, csprng) + } + + /// Consume the message key for a specific counter value. + /// + /// If the counter is behind the current chain index, take a cached + /// (skipped) key. If it's ahead, advance the chain — caching the + /// intermediate keys for out-of-order delivery — up to `max_forward_jumps`. + /// Either way, the key is consumed and cannot be retrieved again. + pub fn consume_message_key( + &mut self, + their_ephemeral: &PublicKey, + mut chain_key: ChainKey, + counter: u32, + // The original message type, for error reporting only. + original_message_type: CiphertextMessageType, + remote_address_for_logging: &str, + ) -> Result { + let chain_index = chain_key.index(); + + if chain_index > counter { + // Counter is in the past — look up a cached key. + return self + .take_skipped_key(their_ephemeral, counter)? + .ok_or_else(|| { + log::info!( + "{remote_address_for_logging} Duplicate message for counter: {counter}" + ); + SignalProtocolError::DuplicatedMessage(chain_index, counter) + }); + } + + let jump = (counter - chain_index) as usize; + if jump > self.max_forward_jumps { + log::error!( + "{remote_address_for_logging} Exceeded future message limit: {}, index: {chain_index}, counter: {counter}", + self.max_forward_jumps, + ); + return Err(SignalProtocolError::InvalidMessage( + original_message_type, + "message from too far into the future", + )); + } else if jump > consts::MAX_FORWARD_JUMPS { + // This only happens if it is a session with self + log::info!( + "{remote_address_for_logging} Jumping ahead {jump} messages (index: {chain_index}, counter: {counter})" + ); + } + + // Advance the chain to the target counter, caching skipped keys. + while chain_key.index() < counter { + self.store_skipped_key(their_ephemeral, chain_key.message_keys()); + chain_key = chain_key.next_chain_key(); + } + + // Update the receiver chain to the next key past the one we're returning. + self.set_receiver_chain_key(their_ephemeral, chain_key.next_chain_key()); + + Ok(chain_key.message_keys()) + } + + // ── Internals ──────────────────────────────────────────────────── + + fn find_receiver_chain_key( + &self, + their_ephemeral: &PublicKey, + ) -> std::result::Result, InvalidSessionError> { + let Some(idx) = self.find_receiver_chain_index(their_ephemeral) else { + return Ok(None); + }; + let chain_key_pb = self.receiver_chains[idx] + .chain_key + .as_ref() + .ok_or(InvalidSessionError("missing receiver chain key"))?; + Ok(Some(ChainKey::from_pb(chain_key_pb)?)) + } + + fn dh_ratchet_step( + &mut self, + their_ephemeral: &PublicKey, + csprng: &mut R, + ) -> Result { + let sender_private_key = self + .sender_chain + .as_ref() + .ok_or(InvalidSessionError("missing sender chain"))? + .ratchet_key + .private_key; + + // Receiving half-step: root_key + DH(our_sender, their_ephemeral) + let current_root_key = self.take_root_key(); + let (new_root_key, receiver_chain_key) = + current_root_key.create_chain(their_ephemeral, &sender_private_key)?; + + // Sending half-step: new_root_key + DH(new_ephemeral, their_ephemeral) + let new_sender_key = KeyPair::generate(csprng); + let (final_root_key, sender_chain_key) = + new_root_key.create_chain(their_ephemeral, &new_sender_key.private_key)?; + + // Record the previous sender chain counter before we replace it. + let current_index = self + .sender_chain + .as_ref() + .expect("checked above") + .chain_key + .index(); + self.previous_counter = current_index.saturating_sub(1); + + self.root_key = final_root_key; + + self.receiver_chains.push(session_structure::Chain { + sender_ratchet_key: their_ephemeral.serialize().to_vec(), + sender_ratchet_key_private: vec![], + chain_key: Some(receiver_chain_key.to_pb()), + message_keys: vec![], + }); + while self.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS { + self.receiver_chains.remove(0); + } + + self.sender_chain = Some(SenderChain { + ratchet_key: new_sender_key, + chain_key: sender_chain_key, + }); + + Ok(receiver_chain_key) + } + + fn take_skipped_key( + &mut self, + their_ephemeral: &PublicKey, + counter: u32, + ) -> std::result::Result, InvalidSessionError> { + let Some(chain_idx) = self.find_receiver_chain_index(their_ephemeral) else { + return Ok(None); + }; + let keys = &mut self.receiver_chains[chain_idx].message_keys; + + // Scan by protobuf index field — no deserialization needed for non-matches. + let Some(pos) = keys.iter().position(|mk| mk.index == counter) else { + return Ok(None); + }; + + // Remove before deserializing. If from_pb fails the key was corrupt + // and unrecoverable — slightly different from main which preserves it, + // but the outcome (error) is the same. + let key_pb = keys.remove(pos); + MessageKeyGenerator::from_pb(key_pb) + .map(Some) + .map_err(InvalidSessionError) + } + + fn store_skipped_key(&mut self, their_ephemeral: &PublicKey, key: MessageKeyGenerator) { + let chain_idx = self + .find_receiver_chain_index(their_ephemeral) + .expect("store_skipped_key called for non-existent chain"); + let keys = &mut self.receiver_chains[chain_idx].message_keys; + + // TODO: This insert(0) is O(n), making the skip-ahead loop O(n²). + // We could switch to push() (appending newest at end) with rposition() + // for search and remove(0) for trimming. That makes the common path + // O(1) per key. Deferred because it changes serialized key order, + // breaking bit-for-bit compatibility with the legacy implementation. + keys.insert(0, key.into_pb()); + if keys.len() > consts::MAX_MESSAGE_KEYS { + keys.pop(); + } + } + + fn set_receiver_chain_key(&mut self, their_ephemeral: &PublicKey, chain_key: ChainKey) { + let chain_idx = self + .find_receiver_chain_index(their_ephemeral) + .expect("set_receiver_chain_key called for non-existent chain"); + self.receiver_chains[chain_idx].chain_key = Some(chain_key.to_pb()); + } + + fn find_receiver_chain_index(&self, their_ephemeral: &PublicKey) -> Option { + self.receiver_chains.iter().position(|chain| { + match PublicKey::deserialize(&chain.sender_ratchet_key) { + Ok(key) => &key == their_ephemeral, + Err(_) => { + log::warn!("skipping corrupt receiver chain with invalid ratchet key"); + false + } + } + }) + } +} diff --git a/rust/protocol/src/handshake.rs b/rust/protocol/src/handshake.rs new file mode 100644 index 000000000..a43a198f9 --- /dev/null +++ b/rust/protocol/src/handshake.rs @@ -0,0 +1,56 @@ +// +// Copyright 2026 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +//! Handshake trait for key agreement protocols. +//! +//! Abstracts over different key agreement protocols (PQXDH, and hypothetical +//! future variants). The trait separates key agreement from ratchet +//! initialization and session management. +//! +//! See [`crate::pqxdh`] for the current production implementation. + +use rand::{CryptoRng, Rng}; + +use crate::Result; + +/// A key agreement protocol used to establish a shared secret during +/// session initialization. +/// +/// Implementors handle the cryptographic key agreement (DH computations, +/// KEM encapsulation/decapsulation, KDF). The resulting session secret is +/// consumed by the ratchet layer to initialize session state. The initiator +/// also produces a message that must be sent to the recipient for them to +/// complete the agreement. +/// +/// The `initiate` method returns `(InitiatorMessage, SessionSecret)` to +/// enforce a clean boundary: the message is data for the wire, the secret +/// is data for the ratchet. These are currently protocol-specific types +/// but should eventually become opaque byte arrays. +pub(crate) trait Handshake { + /// Parameters for the initiator (constructed from a pre-key bundle). + type InitiatorParams; + + /// Parameters for the recipient (constructed from an incoming message). + type RecipientParams<'a>; + + /// Data the initiator must send to the recipient for them to complete + /// the key agreement (e.g., a KEM ciphertext for PQXDH). + type InitiatorMessage; + + /// The shared secret derived from the handshake, consumed by the + /// ratchet layer to initialize session state. + type SessionSecret; + + /// Perform the initiator side of the key agreement. + /// + /// Returns the message to send and the session secret to keep. + fn initiate( + params: &Self::InitiatorParams, + rng: &mut R, + ) -> Result<(Self::InitiatorMessage, Self::SessionSecret)>; + + /// Perform the recipient side of the key agreement. + fn accept(params: &Self::RecipientParams<'_>) -> Result; +} diff --git a/rust/protocol/src/lib.rs b/rust/protocol/src/lib.rs index 49a271fe4..c739fb6e8 100644 --- a/rust/protocol/src/lib.rs +++ b/rust/protocol/src/lib.rs @@ -24,22 +24,28 @@ mod consts; mod crypto; +mod double_ratchet; pub mod error; mod fingerprint; mod group_cipher; +mod handshake; mod identity_key; pub mod incremental_mac; pub mod kem; +pub mod pqxdh; mod proto; mod protocol; mod ratchet; mod sealed_sender; mod sender_keys; mod session; -mod session_cipher; +#[cfg(test)] +mod session_cipher_legacy; +mod session_management; mod state; mod storage; mod timestamp; +mod triple_ratchet; use error::Result; pub use error::SignalProtocolError; @@ -72,7 +78,7 @@ pub use sealed_sender::{ }; pub use sender_keys::SenderKeyRecord; pub use session::{process_prekey, process_prekey_bundle}; -pub use session_cipher::{ +pub use session_management::{ message_decrypt, message_decrypt_prekey, message_decrypt_signal, message_encrypt, }; pub use state::{ diff --git a/rust/protocol/src/pqxdh.rs b/rust/protocol/src/pqxdh.rs new file mode 100644 index 000000000..5d59b1fcc --- /dev/null +++ b/rust/protocol/src/pqxdh.rs @@ -0,0 +1,358 @@ +// +// Copyright 2026 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +//! PQXDH key agreement protocol. +//! +//! This module implements the PQXDH (Post-Quantum Extended Diffie-Hellman) key +//! agreement, extracting the pure key agreement computation from ratchet +//! initialization. The output includes derived keys ready for ratchet setup; +//! the actual ratchet initialization is handled separately in the internal +//! ratchet module. +//! +//! ## Future direction +//! +//! The KDF output shape (`RootKey`, `ChainKey`, `[u8; 32]`) is currently +//! coupled to the Double Ratchet's initialization requirements. Ideally, the +//! handshake would output a single 32-byte secret and the ratchet layer would +//! derive whatever it needs from that. This requires a protocol version bump +//! and should be done alongside a future handshake protocol revision. + +use libsignal_core::derive_arrays; +use rand::{CryptoRng, Rng}; + +use crate::handshake::Handshake; +use crate::ratchet::{ChainKey, RootKey}; +use crate::{ + CiphertextMessageType, IdentityKey, IdentityKeyPair, KeyPair, PublicKey, Result, + SignalProtocolError, kem, +}; + +/// The PQXDH key agreement protocol. +/// +/// Implements [`Handshake`] for the Post-Quantum Extended Diffie-Hellman +/// protocol (4 EC DH + 1 ML-KEM encapsulation/decapsulation). +pub(crate) struct Pqxdh; + +impl Handshake for Pqxdh { + type InitiatorParams = InitiatorParameters; + type RecipientParams<'a> = RecipientParameters<'a>; + type InitiatorMessage = kem::SerializedCiphertext; + type SessionSecret = HandshakeKeys; + + fn initiate( + params: &Self::InitiatorParams, + rng: &mut R, + ) -> Result<(Self::InitiatorMessage, Self::SessionSecret)> { + let result = pqxdh_initiate(params, rng)?; + Ok((result.kyber_ciphertext, result.keys)) + } + + fn accept(params: &Self::RecipientParams<'_>) -> Result { + pqxdh_accept(params) + } +} + +/// The initial PQR (post-quantum ratchet) key derived from the handshake. +pub(crate) type InitialPQRKey = [u8; 32]; + +/// Keys derived from a PQXDH handshake, ready for ratchet initialization. +/// +/// This bundles the KDF output in the shape the ratchet layer expects. +/// See module-level docs for why this is coupled and the plan to decouple. +pub(crate) struct HandshakeKeys { + pub root_key: RootKey, + pub chain_key: ChainKey, + pub pqr_key: InitialPQRKey, +} + +impl HandshakeKeys { + /// Derive ratchet initialization keys from raw PQXDH shared secret material. + fn derive(secret_input: &[u8]) -> Self { + Self::derive_with_label( + b"WhisperText_X25519_SHA-256_CRYSTALS-KYBER-1024", + secret_input, + ) + } + + fn derive_with_label(label: &[u8], secret_input: &[u8]) -> Self { + let (root_key_bytes, chain_key_bytes, pqr_bytes) = derive_arrays(|bytes| { + hkdf::Hkdf::::new(None, secret_input) + .expand(label, bytes) + .expect("valid length") + }); + + Self { + root_key: RootKey::new(root_key_bytes), + chain_key: ChainKey::new(chain_key_bytes, 0), + pqr_key: pqr_bytes, + } + } +} + +// ── Initiator ──────────────────────────────────────────────────────── + +/// The output of a PQXDH key agreement from the initiator's side. +/// +/// Contains the derived handshake keys and the KEM ciphertext that the +/// recipient needs to complete the agreement. +pub(crate) struct InitiatorAgreement { + keys: HandshakeKeys, + kyber_ciphertext: kem::SerializedCiphertext, +} + +/// Parameters for the initiator side of a PQXDH key agreement. +/// +/// The initiator fetches the recipient's pre-key bundle from the server +/// and uses it together with their own identity and ephemeral keys to +/// compute a shared secret. +pub struct InitiatorParameters { + our_identity_key_pair: IdentityKeyPair, + our_ephemeral_key_pair: KeyPair, + + their_identity_key: IdentityKey, + their_signed_pre_key: PublicKey, + their_one_time_pre_key: Option, + their_ratchet_key: PublicKey, + their_kyber_pre_key: kem::PublicKey, +} + +impl InitiatorParameters { + pub fn new( + our_identity_key_pair: IdentityKeyPair, + our_ephemeral_key_pair: KeyPair, + their_identity_key: IdentityKey, + their_signed_pre_key: PublicKey, + their_ratchet_key: PublicKey, + their_kyber_pre_key: kem::PublicKey, + ) -> Self { + Self { + our_identity_key_pair, + our_ephemeral_key_pair, + their_identity_key, + their_one_time_pre_key: None, + their_signed_pre_key, + their_ratchet_key, + their_kyber_pre_key, + } + } + + pub fn set_their_one_time_pre_key(&mut self, ec_public: PublicKey) { + self.their_one_time_pre_key = Some(ec_public); + } + + #[inline] + pub fn our_identity_key_pair(&self) -> &IdentityKeyPair { + &self.our_identity_key_pair + } + + #[inline] + pub fn our_ephemeral_key_pair(&self) -> &KeyPair { + &self.our_ephemeral_key_pair + } + + #[inline] + pub fn their_identity_key(&self) -> &IdentityKey { + &self.their_identity_key + } + + #[inline] + pub fn their_signed_pre_key(&self) -> &PublicKey { + &self.their_signed_pre_key + } + + #[inline] + pub fn their_one_time_pre_key(&self) -> Option<&PublicKey> { + self.their_one_time_pre_key.as_ref() + } + + #[inline] + pub fn their_kyber_pre_key(&self) -> &kem::PublicKey { + &self.their_kyber_pre_key + } + + #[inline] + pub fn their_ratchet_key(&self) -> &PublicKey { + &self.their_ratchet_key + } +} + +/// Perform the initiator side of the PQXDH key agreement. +/// +/// Computes DH shared secrets and KEM encapsulation, then applies the KDF +/// to produce keys ready for ratchet initialization. +pub(crate) fn pqxdh_initiate( + parameters: &InitiatorParameters, + mut csprng: &mut R, +) -> Result { + let mut secrets = Vec::with_capacity(32 * 6); + + secrets.extend_from_slice(&[0xFFu8; 32]); // discontinuity bytes + + secrets.extend_from_slice( + ¶meters + .our_identity_key_pair + .private_key() + .calculate_agreement(¶meters.their_signed_pre_key)?, + ); + + let our_ephemeral_private_key = parameters.our_ephemeral_key_pair.private_key; + + secrets.extend_from_slice( + &our_ephemeral_private_key + .calculate_agreement(parameters.their_identity_key.public_key())?, + ); + + secrets.extend_from_slice( + &our_ephemeral_private_key.calculate_agreement(¶meters.their_signed_pre_key)?, + ); + + if let Some(their_one_time_prekey) = ¶meters.their_one_time_pre_key { + secrets.extend_from_slice( + &our_ephemeral_private_key.calculate_agreement(their_one_time_prekey)?, + ); + } + + let kyber_ciphertext = { + let (ss, ct) = parameters.their_kyber_pre_key.encapsulate(&mut csprng)?; + secrets.extend_from_slice(ss.as_ref()); + ct + }; + + Ok(InitiatorAgreement { + keys: HandshakeKeys::derive(&secrets), + kyber_ciphertext, + }) +} + +// ── Recipient ──────────────────────────────────────────────────────── + +/// Parameters for the recipient side of a PQXDH key agreement. +/// +/// The recipient uses their own pre-keys together with the initiator's +/// identity and base keys (received in the pre-key message) to compute +/// the same shared secret. +pub struct RecipientParameters<'a> { + our_identity_key_pair: IdentityKeyPair, + our_signed_pre_key_pair: KeyPair, + our_one_time_pre_key_pair: Option, + our_kyber_pre_key_pair: kem::KeyPair, + + their_identity_key: IdentityKey, + their_ephemeral_key: PublicKey, + their_kyber_ciphertext: &'a kem::SerializedCiphertext, +} + +impl<'a> RecipientParameters<'a> { + pub fn new( + our_identity_key_pair: IdentityKeyPair, + our_signed_pre_key_pair: KeyPair, + our_one_time_pre_key_pair: Option, + our_kyber_pre_key_pair: kem::KeyPair, + their_identity_key: IdentityKey, + their_ephemeral_key: PublicKey, + their_kyber_ciphertext: &'a kem::SerializedCiphertext, + ) -> Self { + Self { + our_identity_key_pair, + our_signed_pre_key_pair, + our_one_time_pre_key_pair, + our_kyber_pre_key_pair, + their_identity_key, + their_ephemeral_key, + their_kyber_ciphertext, + } + } + + #[inline] + pub fn our_identity_key_pair(&self) -> &IdentityKeyPair { + &self.our_identity_key_pair + } + + #[inline] + pub fn our_signed_pre_key_pair(&self) -> &KeyPair { + &self.our_signed_pre_key_pair + } + + #[inline] + pub fn our_one_time_pre_key_pair(&self) -> Option<&KeyPair> { + self.our_one_time_pre_key_pair.as_ref() + } + + #[inline] + pub fn our_kyber_pre_key_pair(&self) -> &kem::KeyPair { + &self.our_kyber_pre_key_pair + } + + #[inline] + pub fn their_identity_key(&self) -> &IdentityKey { + &self.their_identity_key + } + + #[inline] + pub fn their_ephemeral_key(&self) -> &PublicKey { + &self.their_ephemeral_key + } + + #[inline] + pub fn their_kyber_ciphertext(&self) -> &kem::SerializedCiphertext { + self.their_kyber_ciphertext + } +} + +/// Perform the recipient side of the PQXDH key agreement. +/// +/// Computes DH shared secrets and KEM decapsulation, then applies the KDF +/// to produce keys ready for ratchet initialization. +pub(crate) fn pqxdh_accept(parameters: &RecipientParameters) -> Result { + // Validate the initiator's base key before doing any computation. + if !parameters.their_ephemeral_key.is_canonical() { + return Err(SignalProtocolError::InvalidMessage( + CiphertextMessageType::PreKey, + "incoming base key is invalid", + )); + } + + let mut secrets = Vec::with_capacity(32 * 6); + + secrets.extend_from_slice(&[0xFFu8; 32]); // discontinuity bytes + + secrets.extend_from_slice( + ¶meters + .our_signed_pre_key_pair + .private_key + .calculate_agreement(parameters.their_identity_key.public_key())?, + ); + + secrets.extend_from_slice( + ¶meters + .our_identity_key_pair + .private_key() + .calculate_agreement(¶meters.their_ephemeral_key)?, + ); + + secrets.extend_from_slice( + ¶meters + .our_signed_pre_key_pair + .private_key + .calculate_agreement(¶meters.their_ephemeral_key)?, + ); + + if let Some(our_one_time_pre_key_pair) = ¶meters.our_one_time_pre_key_pair { + secrets.extend_from_slice( + &our_one_time_pre_key_pair + .private_key + .calculate_agreement(¶meters.their_ephemeral_key)?, + ); + } + + secrets.extend_from_slice( + ¶meters + .our_kyber_pre_key_pair + .secret_key + .decapsulate(parameters.their_kyber_ciphertext)?, + ); + + Ok(HandshakeKeys::derive(&secrets)) +} diff --git a/rust/protocol/src/protocol.rs b/rust/protocol/src/protocol.rs index 5531b9301..4046c03c7 100644 --- a/rust/protocol/src/protocol.rs +++ b/rust/protocol/src/protocol.rs @@ -203,7 +203,7 @@ impl SignalMessage { let Some(expected) = Self::serialize_addresses(sender_address, recipient_address) else { log::warn!( - "Local addresses not valid Service IDs: sender={}, recipient={}", + "Locally supplied addresses not valid Service IDs: sender={}, recipient={}", sender_address, recipient_address, ); diff --git a/rust/protocol/src/ratchet.rs b/rust/protocol/src/ratchet.rs index 47f5e5ef9..faa9e1c3d 100644 --- a/rust/protocol/src/ratchet.rs +++ b/rust/protocol/src/ratchet.rs @@ -4,39 +4,25 @@ // mod keys; -mod params; -use libsignal_core::derive_arrays; use rand::{CryptoRng, Rng}; pub(crate) use self::keys::{ChainKey, MessageKeyGenerator, RootKey}; -pub use self::params::{AliceSignalProtocolParameters, BobSignalProtocolParameters}; +use crate::handshake::Handshake; +use crate::pqxdh::{HandshakeKeys, Pqxdh}; +// Re-export the parameter types for backward compatibility. +// Callers (session.rs, tests) use these via `ratchet::`. +pub use crate::pqxdh::{InitiatorParameters, RecipientParameters}; use crate::protocol::CIPHERTEXT_MESSAGE_CURRENT_VERSION; use crate::state::SessionState; use crate::{KeyPair, Result, SessionRecord, SignalProtocolError, consts}; -type InitialPQRKey = [u8; 32]; - -fn derive_keys(secret_input: &[u8]) -> (RootKey, ChainKey, InitialPQRKey) { - derive_keys_with_label( - b"WhisperText_X25519_SHA-256_CRYSTALS-KYBER-1024", - secret_input, - ) -} - -fn derive_keys_with_label(label: &[u8], secret_input: &[u8]) -> (RootKey, ChainKey, InitialPQRKey) { - let (root_key_bytes, chain_key_bytes, pqr_bytes) = derive_arrays(|bytes| { - hkdf::Hkdf::::new(None, secret_input) - .expand(label, bytes) - .expect("valid length") - }); - - let root_key = RootKey::new(root_key_bytes); - let chain_key = ChainKey::new(chain_key_bytes, 0); - let pqr_key: InitialPQRKey = pqr_bytes; - - (root_key, chain_key, pqr_key) -} +// Backward-compatible aliases for the old names. These keep existing +// external callers (tests, bridge code) compiling during the transition. +#[doc(hidden)] +pub type AliceSignalProtocolParameters = InitiatorParameters; +#[doc(hidden)] +pub type BobSignalProtocolParameters<'a> = RecipientParameters<'a>; fn spqr_chain_params(self_connection: bool) -> spqr::ChainParams { #[allow(clippy::needless_update)] @@ -51,47 +37,44 @@ fn spqr_chain_params(self_connection: bool) -> spqr::ChainParams { } } +/// Initialize a session from the initiator's side. +/// +/// Performs the PQXDH key agreement and then sets up the Double Ratchet +/// and SPQR state. pub(crate) fn initialize_alice_session( - parameters: &AliceSignalProtocolParameters, - mut csprng: &mut R, + parameters: &InitiatorParameters, + csprng: &mut R, +) -> Result { + let ( + kyber_ciphertext, + HandshakeKeys { + root_key, + chain_key, + pqr_key, + }, + ) = Pqxdh::initiate(parameters, csprng)?; + + initialize_initiator_session( + parameters, + root_key, + chain_key, + pqr_key, + kyber_ciphertext, + csprng, + ) +} + +fn initialize_initiator_session( + parameters: &InitiatorParameters, + root_key: RootKey, + chain_key: ChainKey, + pqr_key: [u8; 32], + kyber_ciphertext: crate::kem::SerializedCiphertext, + csprng: &mut R, ) -> Result { let local_identity = parameters.our_identity_key_pair().identity_key(); - let mut secrets = Vec::with_capacity(32 * 6); - - secrets.extend_from_slice(&[0xFFu8; 32]); // "discontinuity bytes" - - let our_base_private_key = parameters.our_base_key_pair().private_key; - - secrets.extend_from_slice( - ¶meters - .our_identity_key_pair() - .private_key() - .calculate_agreement(parameters.their_signed_pre_key())?, - ); - - secrets.extend_from_slice( - &our_base_private_key.calculate_agreement(parameters.their_identity_key().public_key())?, - ); - - secrets.extend_from_slice( - &our_base_private_key.calculate_agreement(parameters.their_signed_pre_key())?, - ); - - if let Some(their_one_time_prekey) = parameters.their_one_time_pre_key() { - secrets - .extend_from_slice(&our_base_private_key.calculate_agreement(their_one_time_prekey)?); - } - - let kyber_ciphertext = { - let (ss, ct) = parameters.their_kyber_pre_key().encapsulate(&mut csprng)?; - secrets.extend_from_slice(ss.as_ref()); - ct - }; - - let (root_key, chain_key, pqr_key) = derive_keys(&secrets); - - let sending_ratchet_key = KeyPair::generate(&mut csprng); + let sending_ratchet_key = KeyPair::generate(csprng); let (sending_chain_root_key, sending_chain_chain_key) = root_key.create_chain( parameters.their_ratchet_key(), &sending_ratchet_key.private_key, @@ -118,7 +101,7 @@ pub(crate) fn initialize_alice_session( local_identity, parameters.their_identity_key(), &sending_chain_root_key, - ¶meters.our_base_key_pair().public_key, + ¶meters.our_ephemeral_key_pair().public_key, pqr_state, ) .with_receiver_chain(parameters.their_ratchet_key(), &chain_key) @@ -129,61 +112,38 @@ pub(crate) fn initialize_alice_session( Ok(session) } +/// Initialize a session from the recipient's side. +/// +/// Performs the PQXDH key agreement and then sets up the Double Ratchet +/// and SPQR state. pub(crate) fn initialize_bob_session( - parameters: &BobSignalProtocolParameters, + parameters: &RecipientParameters, + our_ratchet_key_pair: &KeyPair, ) -> Result { - // validate their base key - if !parameters.their_base_key().is_canonical() { - return Err(SignalProtocolError::InvalidMessage( - crate::CiphertextMessageType::PreKey, - "incoming base key is invalid", - )); - } + let HandshakeKeys { + root_key, + chain_key, + pqr_key, + } = Pqxdh::accept(parameters)?; + initialize_recipient_session( + parameters, + our_ratchet_key_pair, + root_key, + chain_key, + pqr_key, + ) +} + +fn initialize_recipient_session( + parameters: &RecipientParameters, + our_ratchet_key_pair: &KeyPair, + root_key: RootKey, + chain_key: ChainKey, + pqr_key: [u8; 32], +) -> Result { let local_identity = parameters.our_identity_key_pair().identity_key(); - let mut secrets = Vec::with_capacity(32 * 6); - - secrets.extend_from_slice(&[0xFFu8; 32]); // "discontinuity bytes" - - secrets.extend_from_slice( - ¶meters - .our_signed_pre_key_pair() - .private_key - .calculate_agreement(parameters.their_identity_key().public_key())?, - ); - - secrets.extend_from_slice( - ¶meters - .our_identity_key_pair() - .private_key() - .calculate_agreement(parameters.their_base_key())?, - ); - - secrets.extend_from_slice( - ¶meters - .our_signed_pre_key_pair() - .private_key - .calculate_agreement(parameters.their_base_key())?, - ); - - if let Some(our_one_time_pre_key_pair) = parameters.our_one_time_pre_key_pair() { - secrets.extend_from_slice( - &our_one_time_pre_key_pair - .private_key - .calculate_agreement(parameters.their_base_key())?, - ); - } - - secrets.extend_from_slice( - ¶meters - .our_kyber_pre_key_pair() - .secret_key - .decapsulate(parameters.their_kyber_ciphertext())?, - ); - - let (root_key, chain_key, pqr_key) = derive_keys(&secrets); - let self_session = local_identity == parameters.their_identity_key(); let pqr_state = spqr::initial_state(spqr::Params { auth_key: &pqr_key, @@ -199,21 +159,22 @@ pub(crate) fn initialize_bob_session( "post-quantum ratchet: error creating initial B2A state: {e}" )) })?; + let session = SessionState::new( CIPHERTEXT_MESSAGE_CURRENT_VERSION, local_identity, parameters.their_identity_key(), &root_key, - parameters.their_base_key(), + parameters.their_ephemeral_key(), pqr_state, ) - .with_sender_chain(parameters.our_ratchet_key_pair(), &chain_key); + .with_sender_chain(our_ratchet_key_pair, &chain_key); Ok(session) } pub fn initialize_alice_session_record( - parameters: &AliceSignalProtocolParameters, + parameters: &InitiatorParameters, csprng: &mut R, ) -> Result { Ok(SessionRecord::new(initialize_alice_session( @@ -222,7 +183,11 @@ pub fn initialize_alice_session_record( } pub fn initialize_bob_session_record( - parameters: &BobSignalProtocolParameters, + parameters: &RecipientParameters, + our_ratchet_key_pair: &KeyPair, ) -> Result { - Ok(SessionRecord::new(initialize_bob_session(parameters)?)) + Ok(SessionRecord::new(initialize_bob_session( + parameters, + our_ratchet_key_pair, + )?)) } diff --git a/rust/protocol/src/ratchet/keys.rs b/rust/protocol/src/ratchet/keys.rs index 3e664eb00..495071500 100644 --- a/rust/protocol/src/ratchet/keys.rs +++ b/rust/protocol/src/ratchet/keys.rs @@ -10,11 +10,21 @@ use libsignal_core::derive_arrays; use crate::proto::storage::session_structure; use crate::{PrivateKey, PublicKey, Result, crypto}; +#[derive(Clone)] pub(crate) enum MessageKeyGenerator { Keys(MessageKeys), Seed((Vec, u32)), } +impl fmt::Debug for MessageKeyGenerator { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Keys(k) => f.debug_tuple("Keys").field(&k.counter()).finish(), + Self::Seed((_, idx)) => f.debug_tuple("Seed").field(idx).finish(), + } + } +} + impl MessageKeyGenerator { pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self { Self::Seed((seed.to_vec(), counter)) diff --git a/rust/protocol/src/ratchet/params.rs b/rust/protocol/src/ratchet/params.rs deleted file mode 100644 index 7f60c6c0d..000000000 --- a/rust/protocol/src/ratchet/params.rs +++ /dev/null @@ -1,159 +0,0 @@ -// -// Copyright 2020 Signal Messenger, LLC. -// SPDX-License-Identifier: AGPL-3.0-only -// - -use crate::{IdentityKey, IdentityKeyPair, KeyPair, PublicKey, kem}; - -pub struct AliceSignalProtocolParameters { - our_identity_key_pair: IdentityKeyPair, - our_base_key_pair: KeyPair, - - their_identity_key: IdentityKey, - their_signed_pre_key: PublicKey, - their_one_time_pre_key: Option, - their_ratchet_key: PublicKey, - their_kyber_pre_key: kem::PublicKey, -} - -impl AliceSignalProtocolParameters { - pub fn new( - our_identity_key_pair: IdentityKeyPair, - our_base_key_pair: KeyPair, - their_identity_key: IdentityKey, - their_signed_pre_key: PublicKey, - their_ratchet_key: PublicKey, - their_kyber_pre_key: kem::PublicKey, - ) -> Self { - Self { - our_identity_key_pair, - our_base_key_pair, - their_identity_key, - their_signed_pre_key, - their_one_time_pre_key: None, - their_ratchet_key, - their_kyber_pre_key, - } - } - - pub fn set_their_one_time_pre_key(&mut self, ec_public: PublicKey) { - self.their_one_time_pre_key = Some(ec_public); - } - - pub fn with_their_one_time_pre_key(mut self, ec_public: PublicKey) -> Self { - self.set_their_one_time_pre_key(ec_public); - self - } - - #[inline] - pub fn our_identity_key_pair(&self) -> &IdentityKeyPair { - &self.our_identity_key_pair - } - - #[inline] - pub fn our_base_key_pair(&self) -> &KeyPair { - &self.our_base_key_pair - } - - #[inline] - pub fn their_identity_key(&self) -> &IdentityKey { - &self.their_identity_key - } - - #[inline] - pub fn their_signed_pre_key(&self) -> &PublicKey { - &self.their_signed_pre_key - } - - #[inline] - pub fn their_one_time_pre_key(&self) -> Option<&PublicKey> { - self.their_one_time_pre_key.as_ref() - } - - #[inline] - pub fn their_kyber_pre_key(&self) -> &kem::PublicKey { - &self.their_kyber_pre_key - } - - #[inline] - pub fn their_ratchet_key(&self) -> &PublicKey { - &self.their_ratchet_key - } -} - -pub struct BobSignalProtocolParameters<'a> { - our_identity_key_pair: IdentityKeyPair, - our_signed_pre_key_pair: KeyPair, - our_one_time_pre_key_pair: Option, - our_ratchet_key_pair: KeyPair, - our_kyber_pre_key_pair: kem::KeyPair, - - their_identity_key: IdentityKey, - their_base_key: PublicKey, - their_kyber_ciphertext: &'a kem::SerializedCiphertext, -} - -impl<'a> BobSignalProtocolParameters<'a> { - #[allow(clippy::too_many_arguments)] - pub fn new( - our_identity_key_pair: IdentityKeyPair, - our_signed_pre_key_pair: KeyPair, - our_one_time_pre_key_pair: Option, - our_ratchet_key_pair: KeyPair, - our_kyber_pre_key_pair: kem::KeyPair, - their_identity_key: IdentityKey, - their_base_key: PublicKey, - their_kyber_ciphertext: &'a kem::SerializedCiphertext, - ) -> Self { - Self { - our_identity_key_pair, - our_signed_pre_key_pair, - our_one_time_pre_key_pair, - our_ratchet_key_pair, - our_kyber_pre_key_pair, - their_identity_key, - their_base_key, - their_kyber_ciphertext, - } - } - - #[inline] - pub fn our_identity_key_pair(&self) -> &IdentityKeyPair { - &self.our_identity_key_pair - } - - #[inline] - pub fn our_signed_pre_key_pair(&self) -> &KeyPair { - &self.our_signed_pre_key_pair - } - - #[inline] - pub fn our_one_time_pre_key_pair(&self) -> Option<&KeyPair> { - self.our_one_time_pre_key_pair.as_ref() - } - - #[inline] - pub fn our_ratchet_key_pair(&self) -> &KeyPair { - &self.our_ratchet_key_pair - } - - #[inline] - pub fn our_kyber_pre_key_pair(&self) -> &kem::KeyPair { - &self.our_kyber_pre_key_pair - } - - #[inline] - pub fn their_identity_key(&self) -> &IdentityKey { - &self.their_identity_key - } - - #[inline] - pub fn their_base_key(&self) -> &PublicKey { - &self.their_base_key - } - - #[inline] - pub fn their_kyber_ciphertext(&self) -> &kem::SerializedCiphertext { - self.their_kyber_ciphertext - } -} diff --git a/rust/protocol/src/sealed_sender.rs b/rust/protocol/src/sealed_sender.rs index c2d44b587..21d77dc54 100644 --- a/rust/protocol/src/sealed_sender.rs +++ b/rust/protocol/src/sealed_sender.rs @@ -23,7 +23,7 @@ use crate::{ IdentityKeyStore, KeyPair, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, PrivateKey, ProtocolAddress, PublicKey, Result, ServiceId, ServiceIdFixedWidthBinaryBytes, SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore, Timestamp, crypto, - message_encrypt, proto, session_cipher, + message_encrypt, proto, session_management, }; #[derive(Debug, Clone)] @@ -2044,7 +2044,7 @@ pub async fn sealed_sender_decrypt( let message = match usmc.msg_type()? { CiphertextMessageType::Whisper => { let ctext = SignalMessage::try_from(usmc.contents()?)?; - session_cipher::message_decrypt_signal( + session_management::message_decrypt_signal( &ctext, &remote_address, session_store, @@ -2055,7 +2055,7 @@ pub async fn sealed_sender_decrypt( } CiphertextMessageType::PreKey => { let ctext = PreKeySignalMessage::try_from(usmc.contents()?)?; - session_cipher::message_decrypt_prekey( + session_management::message_decrypt_prekey( &ctext, &remote_address, &local_address, diff --git a/rust/protocol/src/session.rs b/rust/protocol/src/session.rs index e05693acd..b91e19ac2 100644 --- a/rust/protocol/src/session.rs +++ b/rust/protocol/src/session.rs @@ -145,16 +145,16 @@ async fn process_prekey_impl( let parameters = BobSignalProtocolParameters::new( identity_store.get_identity_key_pair().await?, - our_signed_pre_key_pair, // signed pre key + our_signed_pre_key_pair, our_one_time_pre_key_pair, - our_signed_pre_key_pair, // ratchet key our_kyber_pre_key_pair, *message.identity_key(), *message.base_key(), kyber_ciphertext, ); - let mut new_session = ratchet::initialize_bob_session(¶meters)?; + // The recipient's initial ratchet key is the signed pre-key. + let mut new_session = ratchet::initialize_bob_session(¶meters, &our_signed_pre_key_pair)?; new_session.set_local_registration_id(identity_store.get_local_registration_id().await?); new_session.set_remote_registration_id(message.registration_id()); diff --git a/rust/protocol/src/session_cipher.rs b/rust/protocol/src/session_cipher_legacy.rs similarity index 96% rename from rust/protocol/src/session_cipher.rs rename to rust/protocol/src/session_cipher_legacy.rs index d43a98d7e..04f6f8993 100644 --- a/rust/protocol/src/session_cipher.rs +++ b/rust/protocol/src/session_cipher_legacy.rs @@ -1,7 +1,16 @@ // -// Copyright 2020-2022 Signal Messenger, LLC. +// Copyright 2020-2026 Signal Messenger, LLC. // SPDX-License-Identifier: AGPL-3.0-only // +// Frozen snapshot of session_cipher.rs taken before the protocol refactoring. +// Used exclusively for interoperability tests: the legacy encrypt/decrypt paths +// must be able to exchange messages with the new implementation. +// +// DO NOT modify the crypto logic in this file. The point is to have an +// immutable reference implementation to test against. + +#![cfg(test)] +#![allow(dead_code)] use std::time::SystemTime; @@ -16,7 +25,7 @@ use crate::{ SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore, session, }; -pub async fn message_encrypt( +pub async fn legacy_message_encrypt( ptext: &[u8], remote_address: &ProtocolAddress, local_address: &ProtocolAddress, @@ -162,7 +171,7 @@ pub async fn message_encrypt( } #[allow(clippy::too_many_arguments)] -pub async fn message_decrypt( +pub async fn legacy_message_decrypt( ciphertext: &CiphertextMessage, remote_address: &ProtocolAddress, local_address: &ProtocolAddress, @@ -175,11 +184,11 @@ pub async fn message_decrypt( ) -> Result> { match ciphertext { CiphertextMessage::SignalMessage(m) => { - let _ = local_address; - message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await + legacy_message_decrypt_signal(m, remote_address, session_store, identity_store, csprng) + .await } CiphertextMessage::PreKeySignalMessage(m) => { - message_decrypt_prekey( + legacy_message_decrypt_prekey( m, remote_address, local_address, @@ -200,7 +209,7 @@ pub async fn message_decrypt( } #[allow(clippy::too_many_arguments)] -pub async fn message_decrypt_prekey( +pub async fn legacy_message_decrypt_prekey( ciphertext: &PreKeySignalMessage, remote_address: &ProtocolAddress, local_address: &ProtocolAddress, @@ -285,7 +294,7 @@ pub async fn message_decrypt_prekey( Ok(ptext) } -pub async fn message_decrypt_signal( +pub async fn legacy_message_decrypt_signal( ciphertext: &SignalMessage, remote_address: &ProtocolAddress, session_store: &mut dyn SessionStore, diff --git a/rust/protocol/src/session_management.rs b/rust/protocol/src/session_management.rs new file mode 100644 index 000000000..119780443 --- /dev/null +++ b/rust/protocol/src/session_management.rs @@ -0,0 +1,3113 @@ +// +// Copyright 2026 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +//! Session management and public encrypt/decrypt API for Signal 1:1 messaging. +//! +//! This module owns two things: +//! +//! 1. **The public API** — [`message_encrypt`], [`message_decrypt`], +//! [`message_decrypt_signal`], [`message_decrypt_prekey`]. These are the +//! entry points used by the bridge layer and `sealed_sender`. +//! +//! 2. **Sesame session management** — the "which session do we use?" logic: +//! trial-decryption across current and previous sessions, session promotion +//! on success, and session selection for encryption. +//! +//! All cryptographic ratchet operations are delegated to +//! [`TripleRatchet`]. This module has no knowledge of chain keys, +//! root keys, or SPQR internals. + +use std::time::SystemTime; + +use displaydoc::Display; +use rand::{CryptoRng, Rng}; + +use crate::consts::MAX_UNACKNOWLEDGED_SESSION_AGE; +use crate::state::{InvalidSessionError, SessionState}; +use crate::triple_ratchet::{OutgoingTripleRatchet, TripleRatchet}; +use crate::{ + CiphertextMessage, CiphertextMessageType, Direction, IdentityKeyStore, KyberPayload, + KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, Result, SessionRecord, + SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore, session, +}; + +// ── Public API ─────────────────────────────────────────────────────────────── + +/// Encrypt `ptext` for `remote_address`, loading and storing session state. +/// +/// If the session is unacknowledged (a locally-initiated session that has not +/// yet received a response), wraps the [`SignalMessage`] in a +/// [`PreKeySignalMessage`] containing the original pre-key material. +pub async fn message_encrypt( + ptext: &[u8], + remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, + session_store: &mut dyn SessionStore, + identity_store: &mut dyn IdentityKeyStore, + now: SystemTime, + csprng: &mut R, +) -> Result { + let mut session_record = session_store + .load_session(remote_address) + .await? + .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?; + let session_state = session_record + .session_state_mut() + .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?; + + let mut session = OutgoingTripleRatchet::from_session_state(session_state).map_err(|e| { + log::error!("session state corrupt for {remote_address}: {e}"); + e + })?; + + let their_identity_key = session_state + .remote_identity_key()? + .expect("session was valid; must have remote identity key"); + + // Pre-key wrapping — session management concern. + let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? { + let timestamp_as_unix_time = items + .timestamp() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if items.timestamp() + MAX_UNACKNOWLEDGED_SESSION_AGE < now { + log::warn!( + "stale unacknowledged session for {remote_address} (created at {timestamp_as_unix_time})" + ); + return Err(SignalProtocolError::SessionNotFound(remote_address.clone())); + } + + let local_registration_id = session_state.local_registration_id(); + + log::info!( + "Building PreKeyWhisperMessage for: {} with preKeyId: {} (session created at {})", + remote_address, + items + .pre_key_id() + .map_or_else(|| "".to_string(), |id| id.to_string()), + timestamp_as_unix_time, + ); + + let kyber_payload = items + .kyber_pre_key_id() + .zip(items.kyber_ciphertext()) + .map(|(id, ciphertext)| KyberPayload::new(id, ciphertext.into())); + let signal_message = session.encrypt(ptext, Some(local_address), remote_address, csprng)?; + + CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new( + session.session_version(), + local_registration_id, + items.pre_key_id(), + items.signed_pre_key_id(), + kyber_payload, + *items.base_key(), + *session.local_identity_key(), + signal_message, + )?) + } else { + let signal_message = session.encrypt(ptext, None, remote_address, csprng)?; + CiphertextMessage::SignalMessage(signal_message) + }; + + // In clients, `is_trusted_identity` for the Sending direction checks + // whether the session's identity key matches the stored key AND whether the + // user has approved it (safety number changes, verification status). This + // prevents sending to a contact whose identity has changed without user + // acknowledgment. + if !identity_store + .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending) + .await? + { + log::warn!( + "Identity key {} is not trusted for remote address {}", + hex::encode(their_identity_key.public_key().public_key_bytes()), + remote_address, + ); + return Err(SignalProtocolError::UntrustedIdentity( + remote_address.clone(), + )); + } + + identity_store + .save_identity(remote_address, &their_identity_key) + .await?; + + // Commit and save session state changes. + session.apply_to_session_state(session_state); + + session_store + .store_session(remote_address, &session_record) + .await?; + Ok(message) +} + +/// Decrypt a [`CiphertextMessage`] from `remote_address`. +/// +/// Routes to [`message_decrypt_signal`] or [`message_decrypt_prekey`] based +/// on message type. +#[allow(clippy::too_many_arguments)] +pub async fn message_decrypt( + ciphertext: &CiphertextMessage, + remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, + session_store: &mut dyn SessionStore, + identity_store: &mut dyn IdentityKeyStore, + pre_key_store: &mut dyn PreKeyStore, + signed_pre_key_store: &dyn SignedPreKeyStore, + kyber_pre_key_store: &mut dyn KyberPreKeyStore, + csprng: &mut R, +) -> Result> { + match ciphertext { + CiphertextMessage::SignalMessage(m) => { + message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await + } + CiphertextMessage::PreKeySignalMessage(m) => { + message_decrypt_prekey( + m, + remote_address, + local_address, + session_store, + identity_store, + pre_key_store, + signed_pre_key_store, + kyber_pre_key_store, + csprng, + ) + .await + } + _ => Err(SignalProtocolError::InvalidArgument(format!( + "message_decrypt cannot be used to decrypt {:?} messages", + ciphertext.message_type() + ))), + } +} + +/// Decrypt a [`PreKeySignalMessage`] from `remote_address`. +/// +/// Processes the pre-key material to establish a session (via +/// [`session::process_prekey`]), then decrypts the inner [`SignalMessage`]. +#[allow(clippy::too_many_arguments)] +pub async fn message_decrypt_prekey( + ciphertext: &PreKeySignalMessage, + remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, + session_store: &mut dyn SessionStore, + identity_store: &mut dyn IdentityKeyStore, + pre_key_store: &mut dyn PreKeyStore, + signed_pre_key_store: &dyn SignedPreKeyStore, + kyber_pre_key_store: &mut dyn KyberPreKeyStore, + csprng: &mut R, +) -> Result> { + let mut session_record = session_store + .load_session(remote_address) + .await? + .unwrap_or_else(SessionRecord::new_fresh); + + // Make sure we log the session state if we fail to process the pre-key. + let process_prekey_result = session::process_prekey( + ciphertext, + remote_address, + &mut session_record, + identity_store, + pre_key_store, + signed_pre_key_store, + kyber_pre_key_store, + ) + .await; + + let (pre_key_used, identity_to_save) = match process_prekey_result { + Ok(result) => result, + Err(e) => { + let errs = [e]; + log::error!( + "{}", + format_decryption_failure_log( + remote_address, + &errs, + &session_record, + ciphertext.message() + )? + ); + let [e] = errs; + return Err(e); + } + }; + + let ptext = try_decrypt_from_record( + &mut session_record, + remote_address, + Some(local_address), + ciphertext.message(), + CiphertextMessageType::PreKey, + csprng, + )?; + + identity_store + .save_identity( + identity_to_save.remote_address, + identity_to_save.their_identity_key, + ) + .await?; + + if let Some(pre_key_used) = pre_key_used { + if let Some(kyber_pre_key_id) = pre_key_used.kyber_pre_key_id { + kyber_pre_key_store + .mark_kyber_pre_key_used( + kyber_pre_key_id, + pre_key_used.signed_ec_pre_key_id, + ciphertext.base_key(), + ) + .await?; + } + + if let Some(pre_key_id) = pre_key_used.one_time_ec_pre_key_id { + pre_key_store.remove_pre_key(pre_key_id).await?; + } + } + + session_store + .store_session(remote_address, &session_record) + .await?; + + Ok(ptext) +} + +/// Decrypt a [`SignalMessage`] from `remote_address`. +/// +/// Tries all sessions in the session record. Checks identity key trust +/// after decryption. +pub async fn message_decrypt_signal( + ciphertext: &SignalMessage, + remote_address: &ProtocolAddress, + session_store: &mut dyn SessionStore, + identity_store: &mut dyn IdentityKeyStore, + csprng: &mut R, +) -> Result> { + let mut session_record = session_store + .load_session(remote_address) + .await? + .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?; + + let ptext = try_decrypt_from_record( + &mut session_record, + remote_address, + None, + ciphertext, + CiphertextMessageType::Whisper, + csprng, + )?; + + // Why are we performing this check after decryption instead of before? + let their_identity_key = session_record + .session_state() + .expect("successfully decrypted; must have a current state") + .remote_identity_key() + .expect("successfully decrypted; must have a remote identity key") + .expect("successfully decrypted; must have a remote identity key"); + + if !identity_store + .is_trusted_identity(remote_address, &their_identity_key, Direction::Receiving) + .await? + { + log::warn!( + "Identity key {} is not trusted for remote address {}", + hex::encode(their_identity_key.public_key().public_key_bytes()), + remote_address, + ); + return Err(SignalProtocolError::UntrustedIdentity( + remote_address.clone(), + )); + } + + identity_store + .save_identity(remote_address, &their_identity_key) + .await?; + + session_store + .store_session(remote_address, &session_record) + .await?; + + Ok(ptext) +} + +// ── Session management (Sesame) ────────────────────────────────────────────── + +/// Try to decrypt `ciphertext` against every session in `record`, in order. +/// +/// Tries the current session first, then previous sessions. On success from +/// a previous session, promotes that session to current (Sesame behavior). +/// +/// `original_message_type` is `Whisper` for normal messages and `PreKey` for +/// the inner `SignalMessage` of a pre-key message. When it is `PreKey`, we +/// skip the fallback to previous sessions — a PreKey message establishes a +/// fresh session and should always match the current one. +pub(crate) fn try_decrypt_from_record( + record: &mut SessionRecord, + remote_address: &ProtocolAddress, + local_address: Option<&ProtocolAddress>, + ciphertext: &SignalMessage, + original_message_type: CiphertextMessageType, + csprng: &mut R, +) -> Result> { + debug_assert!(matches!( + original_message_type, + CiphertextMessageType::Whisper | CiphertextMessageType::PreKey + )); + let ciphertext_version = ciphertext.message_version() as u32; + + let log_failure = |label: &str, state: &SessionState, error: &SignalProtocolError| { + log::warn!( + "Failed to decrypt {:?} message with ratchet key: {} and counter: {}. \ + Session loaded for {}. {} session has base key: {} and counter: {}. {}", + original_message_type, + hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()), + ciphertext.counter(), + remote_address, + label, + state + .sender_ratchet_key_for_logging() + .unwrap_or_else(|e| format!("")), + state.previous_counter(), + error + ); + }; + + let mut errs = vec![]; + + // ── Try current session ────────────────────────────────────────────────── + + if let Some(current_state) = record.session_state() { + let mut current_state = current_state.clone(); + + if current_state.session_version()? != ciphertext_version { + let e = SignalProtocolError::UnrecognizedMessageVersion(ciphertext_version); + log_failure("Current", ¤t_state, &e); + errs.push(e); + } else { + match try_decrypt_with_state( + &mut current_state, + remote_address, + local_address, + ciphertext, + original_message_type, + CurrentOrPrevious::Current, + csprng, + ) { + Ok(ptext) => { + log::info!( + "decrypted {:?} message from {} with current session state (base key {})", + original_message_type, + remote_address, + current_state + .sender_ratchet_key_for_logging() + .expect("successful decrypt always has a valid base key"), + ); + record.set_session_state(current_state); + return Ok(ptext); + } + Err(e @ SignalProtocolError::DuplicatedMessage(_, _)) => return Err(e), + Err(e) => { + log_failure("Current", ¤t_state, &e); + errs.push(e); + match original_message_type { + CiphertextMessageType::PreKey => { + // A PreKey message creates a session and then decrypts a Whisper message + // using that session. No need to check older sessions. + log::error!( + "{}", + format_decryption_failure_log( + remote_address, + &errs, + record, + ciphertext, + )? + ); + // Note that we don't propagate `e` here; we always return InvalidMessage, + // as we would for a Whisper message that tried several sessions. + return Err(SignalProtocolError::InvalidMessage( + original_message_type, + "decryption failed", + )); + } + CiphertextMessageType::Whisper => {} + CiphertextMessageType::SenderKey | CiphertextMessageType::Plaintext => { + unreachable!("should not be using Double Ratchet for these") + } + } + } + } + } + } + + // ── Try previous sessions (Whisper only) ───────────────────────────────── + + let mut promoted = None; + + for (idx, previous) in record.previous_session_states().enumerate() { + let mut previous = match previous { + Ok(previous) => previous, + Err(e) => { + let e: SignalProtocolError = e.into(); + log::warn!( + "Skipping corrupt previous session {} for {}: {}", + idx, + remote_address, + e + ); + errs.push(e); + continue; + } + }; + + if previous.session_version()? != ciphertext_version { + let e = SignalProtocolError::UnrecognizedMessageVersion(ciphertext_version); + log_failure("Previous", &previous, &e); + errs.push(e); + continue; + } + + match try_decrypt_with_state( + &mut previous, + remote_address, + local_address, + ciphertext, + original_message_type, + CurrentOrPrevious::Previous, + csprng, + ) { + Ok(ptext) => { + log::info!( + "decrypted {:?} message from {} with PREVIOUS session state (base key {})", + original_message_type, + remote_address, + previous + .sender_ratchet_key_for_logging() + .expect("successful decrypt always has a valid base key"), + ); + promoted = Some((ptext, idx, previous)); + break; + } + Err(e @ SignalProtocolError::DuplicatedMessage(_, _)) => return Err(e), + Err(e) => { + log_failure("Previous", &previous, &e); + errs.push(e); + } + } + } + + if let Some((ptext, idx, updated)) = promoted { + // Sesame: promote the successful previous session to current. + // The upcoming session management update will remove this promotion. + record.promote_old_session(idx, updated); + Ok(ptext) + } else { + let previous_state_count = || record.previous_session_states().len(); + if let Some(current_state) = record.session_state() { + log::error!( + "No valid session for recipient: {}, current session base key {}, \ + number of previous states: {}", + remote_address, + current_state + .sender_ratchet_key_for_logging() + .unwrap_or_else(|e| format!("")), + previous_state_count(), + ); + } else { + log::error!( + "No valid session for recipient: {}, (no current session state), \ + number of previous states: {}", + remote_address, + previous_state_count(), + ); + } + log::error!( + "{}", + format_decryption_failure_log(remote_address, &errs, record, ciphertext)? + ); + Err(SignalProtocolError::InvalidMessage( + original_message_type, + "decryption failed", + )) + } +} + +// ── Per-session decrypt ────────────────────────────────────────────────────── + +/// Attempt to decrypt `ciphertext` using the crypto state in `state`. +/// +/// Caller must only pass version-compatible ciphertext/session pairs. +/// +/// Constructs a [`TripleRatchet`], delegates the actual decryption, and writes +/// updated state back on success. On failure, `state` is unchanged. +pub(crate) fn try_decrypt_with_state( + state: &mut SessionState, + remote_address: &ProtocolAddress, + local_address: Option<&ProtocolAddress>, + ciphertext: &SignalMessage, + original_message_type: CiphertextMessageType, + curr_or_prev_for_logging: CurrentOrPrevious, + csprng: &mut R, +) -> Result> { + debug_assert_eq!( + state.session_version()?, + ciphertext.message_version() as u32 + ); + + let mut session = TripleRatchet::from_session_state(state)?; + + let ptext = session.decrypt( + remote_address, + local_address, + ciphertext, + original_message_type, + curr_or_prev_for_logging, + csprng, + )?; + + session.apply_to_session_state(state); + state.clear_unacknowledged_pre_key_message(); + + Ok(ptext) +} + +// ── Logging helpers ────────────────────────────────────────────────────────── + +pub(crate) fn format_decryption_failure_log( + remote_address_for_logging: &ProtocolAddress, + mut errs: &[SignalProtocolError], + record: &SessionRecord, + ciphertext: &SignalMessage, +) -> Result { + fn append_session_summary( + lines: &mut Vec, + idx: usize, + state: std::result::Result<&SessionState, InvalidSessionError>, + err: Option<&SignalProtocolError>, + ) { + let chains = state.map(|state| state.all_receiver_chain_logging_info()); + match (err, &chains) { + (Some(err), Ok(chains)) => { + lines.push(format!( + "Candidate session {} failed with '{}', had {} receiver chains", + idx, + err, + chains.len() + )); + } + (Some(err), Err(state_err)) => { + lines.push(format!( + "Candidate session {idx} failed with '{err}'; \ + cannot get receiver chain info ({state_err})", + )); + } + (None, Ok(chains)) => { + lines.push(format!( + "Candidate session {} had {} receiver chains", + idx, + chains.len() + )); + } + (None, Err(state_err)) => { + lines.push(format!( + "Candidate session {idx}: cannot get receiver chain info ({state_err})", + )); + } + } + + if let Ok(chains) = chains { + for chain in chains { + let chain_idx = match chain.1 { + Some(i) => i.to_string(), + None => "missing in protobuf".to_string(), + }; + lines.push(format!( + "Receiver chain with sender ratchet public key {} chain key index {}", + hex::encode(chain.0), + chain_idx + )); + } + } + } + + let mut lines = vec![]; + lines.push(format!( + "Message from {} failed to decrypt; sender ratchet public key {} message counter {}", + remote_address_for_logging, + hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()), + ciphertext.counter() + )); + + if let Some(current_session) = record.session_state() { + let err = errs.first(); + if err.is_some() { + errs = &errs[1..]; + } + append_session_summary(&mut lines, 0, Ok(current_session), err); + } else { + lines.push("No current session".to_string()); + } + + for (idx, (state, err)) in record + .previous_session_states() + .zip(errs.iter().map(Some).chain(std::iter::repeat(None))) + .enumerate() + { + let state = match state { + Ok(ref state) => Ok(state), + Err(err) => Err(err), + }; + append_session_summary(&mut lines, idx + 1, state, err); + } + + Ok(lines.join("\n")) +} + +#[derive(Clone, Copy, Display)] +pub(crate) enum CurrentOrPrevious { + /// current + Current, + /// previous + Previous, +} + +// ── Comparison proptest ────────────────────────────────────────────────────── +// +// Verifies that the refactored encrypt/decrypt path produces identical results +// to the legacy snapshot for any message sequence. +#[cfg(test)] +mod legacy_interop_tests { + // These tests live next to `session_management` rather than under + // `rust/protocol/tests/` because they compare the refactored code against + // the private `session_cipher_legacy` implementation and also assert + // byte-level equivalence of internal persisted state. That makes them + // implementation-regression tests for this refactor, not normal public API + // integration tests. + // + // This harness is temporary. Once we are confident in the refactor, remove + // `session_cipher_legacy` and the new-vs-legacy equivalence tests along + // with it. + use futures_util::FutureExt; + use libsignal_protocol_test_support::Event; + use proptest::prelude::*; + use prost::Message; + use rand::SeedableRng; + use rand_chacha::ChaCha8Rng; + + use super::*; + use crate::proto::storage::RecordStructure; + use crate::ratchet::{ + AliceSignalProtocolParameters, BobSignalProtocolParameters, + initialize_alice_session_record, initialize_bob_session_record, + }; + use crate::{ + DecryptionErrorMessage, DeviceId, GenericSignedPreKey, IdentityKeyPair, + InMemSignalProtocolStore, KeyPair, KyberPreKeyId, KyberPreKeyRecord, PlaintextContent, + PreKeyBundle, PreKeyId, PreKeyRecord, ProtocolAddress, SessionRecord, + SessionUsabilityRequirements, SignalProtocolError, SignedPreKeyId, SignedPreKeyRecord, + Timestamp, extract_decryption_error_message_from_serialized_content, process_prekey_bundle, + session_cipher_legacy as legacy, + }; + + #[derive(Clone, Copy, PartialEq, Eq, Debug)] + enum MessageStatus { + Sent, + Dropped, + Delivered, + } + + #[derive(Clone)] + struct DualLocalState { + new_store: InMemSignalProtocolStore, + legacy_store: InMemSignalProtocolStore, + pre_key_count: u32, + } + + struct DualParticipant { + address: ProtocolAddress, + message_queue: Vec<(CiphertextMessage, u64)>, + state: DualLocalState, + snapshots: Vec, + message_send_log: Vec, + } + + /// Build a matched (alice, bob) session pair from a seeded RNG. + fn setup_stores( + rng: &mut ChaCha8Rng, + ) -> ( + InMemSignalProtocolStore, + InMemSignalProtocolStore, + ProtocolAddress, + ProtocolAddress, + ) { + let alice_identity = IdentityKeyPair::generate(rng); + let bob_identity = IdentityKeyPair::generate(rng); + + let alice_base_key = KeyPair::generate(rng); + let bob_signed_pre_key = KeyPair::generate(rng); + let bob_kyber_key = crate::kem::KeyPair::generate(crate::kem::KeyType::Kyber1024, rng); + + let alice_params = AliceSignalProtocolParameters::new( + alice_identity, + alice_base_key, + *bob_identity.identity_key(), + bob_signed_pre_key.public_key, + bob_signed_pre_key.public_key, + bob_kyber_key.public_key.clone(), + ); + + let alice_record = + initialize_alice_session_record(&alice_params, rng).expect("alice session init"); + let kyber_ct: Box<[u8]> = alice_record + .get_kyber_ciphertext() + .expect("session valid") + .expect("has kyber ciphertext") + .clone() + .into_boxed_slice(); + + let bob_params = BobSignalProtocolParameters::new( + bob_identity, + bob_signed_pre_key, + None, + bob_kyber_key, + *alice_identity.identity_key(), + alice_base_key.public_key, + &kyber_ct, + ); + + let bob_record = initialize_bob_session_record(&bob_params, &bob_signed_pre_key) + .expect("bob session init"); + + let alice_address = ProtocolAddress::new( + "57721566-4901-5328-6060-651209008240".to_owned(), + DeviceId::new(1).unwrap(), + ); + let bob_address = ProtocolAddress::new( + "26149721-2847-6427-8375-542683860869".to_owned(), + DeviceId::new(1).unwrap(), + ); + + let mut alice_store = InMemSignalProtocolStore::new(alice_identity, 1).unwrap(); + let mut bob_store = InMemSignalProtocolStore::new(bob_identity, 2).unwrap(); + + alice_store + .session_store + .store_session(&bob_address, &alice_record) + .now_or_never() + .unwrap() + .unwrap(); + bob_store + .session_store + .store_session(&alice_address, &bob_record) + .now_or_never() + .unwrap() + .unwrap(); + + (alice_store, bob_store, alice_address, bob_address) + } + + /// Create a pre-key bundle for Bob, storing new key material in his store. + /// + /// The `*_id` parameters must not collide with any IDs already in the + /// store. Using a monotonically increasing generation counter (1, 2, …) + /// is sufficient. + fn create_bob_bundle( + bob_store: &mut InMemSignalProtocolStore, + pre_key_id: u32, + signed_pre_key_id: u32, + kyber_pre_key_id: u32, + rng: &mut ChaCha8Rng, + ) -> PreKeyBundle { + let identity_key_pair = bob_store + .get_identity_key_pair() + .now_or_never() + .unwrap() + .unwrap(); + + let pre_key = KeyPair::generate(rng); + let signed_pre_key = KeyPair::generate(rng); + let kyber_key = crate::kem::KeyPair::generate(crate::kem::KeyType::Kyber1024, rng); + + let pk_id = PreKeyId::from(pre_key_id); + let spk_id = SignedPreKeyId::from(signed_pre_key_id); + let kpk_id = KyberPreKeyId::from(kyber_pre_key_id); + + let spk_sig = identity_key_pair + .private_key() + .calculate_signature(&signed_pre_key.public_key.serialize(), rng) + .unwrap(); + let kpk_sig = identity_key_pair + .private_key() + .calculate_signature(&kyber_key.public_key.serialize(), rng) + .unwrap(); + + bob_store + .save_pre_key(pk_id, &PreKeyRecord::new(pk_id, &pre_key)) + .now_or_never() + .unwrap() + .unwrap(); + bob_store + .save_signed_pre_key( + spk_id, + &SignedPreKeyRecord::new( + spk_id, + Timestamp::from_epoch_millis(42), + &signed_pre_key, + &spk_sig, + ), + ) + .now_or_never() + .unwrap() + .unwrap(); + bob_store + .save_kyber_pre_key( + kpk_id, + &KyberPreKeyRecord::new( + kpk_id, + Timestamp::from_epoch_millis(43), + &kyber_key, + &kpk_sig, + ), + ) + .now_or_never() + .unwrap() + .unwrap(); + + let reg_id = bob_store + .get_local_registration_id() + .now_or_never() + .unwrap() + .unwrap(); + + PreKeyBundle::new( + reg_id, + DeviceId::new(1).unwrap(), + Some((pk_id, pre_key.public_key)), + spk_id, + signed_pre_key.public_key, + spk_sig.to_vec(), + kpk_id, + kyber_key.public_key.clone(), + kpk_sig.to_vec(), + *identity_key_pair.identity_key(), + ) + .unwrap() + } + + #[test] + fn encrypt_preserves_corruption_error_instead_of_session_not_found() { + let mut rng = ChaCha8Rng::seed_from_u64(0xC0FFEE); + let (mut alice_store, _bob_store, alice_address, bob_address) = setup_stores(&mut rng); + let now = SystemTime::now(); + + let good_record = alice_store + .session_store + .load_session(&bob_address) + .now_or_never() + .expect("sync") + .expect("load succeeded") + .expect("session exists"); + + let serialized = good_record.serialize().expect("serialize"); + let mut record_pb = RecordStructure::decode(serialized.as_slice()).expect("decode record"); + record_pb + .current_session + .as_mut() + .expect("current session") + .remote_identity_public = vec![0xFF]; + + let corrupted_record = SessionRecord::deserialize(record_pb.encode_to_vec().as_slice()) + .expect("deserialize corrupted record"); + + alice_store + .session_store + .store_session(&bob_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let legacy_err = legacy::legacy_message_encrypt( + b"test", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect_err("legacy encrypt should fail on corrupted state"); + assert!( + matches!( + legacy_err, + SignalProtocolError::InvalidSessionStructure("invalid remote identity key") + ), + "unexpected legacy error: {legacy_err:?}" + ); + + alice_store + .session_store + .store_session(&bob_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let new_err = message_encrypt( + b"test", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect_err("new encrypt should fail on corrupted state"); + assert!( + matches!( + new_err, + SignalProtocolError::InvalidSessionStructure("invalid remote identity key") + ), + "unexpected new error: {new_err:?}" + ); + } + + #[test] + fn encrypt_ignores_corrupt_unused_receiver_chain() { + let mut rng = ChaCha8Rng::seed_from_u64(0xACE55); + let (mut alice_store, _bob_store, alice_address, bob_address) = setup_stores(&mut rng); + let now = SystemTime::now(); + + let good_record = alice_store + .session_store + .load_session(&bob_address) + .now_or_never() + .expect("sync") + .expect("load succeeded") + .expect("session exists"); + + let serialized = good_record.serialize().expect("serialize"); + let mut record_pb = RecordStructure::decode(serialized.as_slice()).expect("decode record"); + let current_session = record_pb.current_session.as_mut().expect("current session"); + assert!( + !current_session.receiver_chains.is_empty(), + "expected at least one receiver chain" + ); + current_session.receiver_chains[0].sender_ratchet_key = vec![0xFF]; + + let corrupted_record = SessionRecord::deserialize(record_pb.encode_to_vec().as_slice()) + .expect("deserialize corrupted record"); + + alice_store + .session_store + .store_session(&bob_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let mut legacy_rng = rng.clone(); + let legacy_ct = legacy::legacy_message_encrypt( + b"test", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut legacy_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy encrypt should ignore unused receiver-chain corruption"); + + alice_store + .session_store + .store_session(&bob_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let new_ct = message_encrypt( + b"test", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new encrypt should ignore unused receiver-chain corruption"); + + let legacy_msg = match legacy_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage from legacy enc, got {:?}", + other.message_type() + ), + }; + let new_msg = match new_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage from new enc, got {:?}", + other.message_type() + ), + }; + + assert_eq!(legacy_msg.serialized(), new_msg.serialized()); + } + + #[test] + fn decrypt_skips_corrupt_previous_session_and_uses_later_valid_previous() { + let mut rng = ChaCha8Rng::seed_from_u64(0xBAD5EED); + let (mut alice_store, mut bob_store, alice_address, bob_address) = setup_stores(&mut rng); + let now = SystemTime::now(); + + let delayed_plaintext = b"delayed on session A".to_vec(); + + let delayed_ct = legacy::legacy_message_encrypt( + &delayed_plaintext, + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("delayed legacy enc"); + + let delayed_signal_msg = match delayed_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage for delayed msg, got {:?}", + other.message_type() + ), + }; + + let bundle = create_bob_bundle(&mut bob_store, 1, 1, 1, &mut rng); + process_prekey_bundle( + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &bundle, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("process_prekey_bundle"); + + let session_b_init = message_encrypt( + b"session B init", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B init enc"); + + message_decrypt( + &session_b_init, + &alice_address, + &bob_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut bob_store.pre_key_store, + &bob_store.signed_pre_key_store, + &mut bob_store.kyber_pre_key_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B init dec"); + + let session_b_ack = message_encrypt( + b"session B ack", + &alice_address, + &bob_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B ack enc"); + + let session_b_ack_signal = match &session_b_ack { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected Whisper for session B ack, got {:?}", + other.message_type() + ), + }; + message_decrypt_signal( + session_b_ack_signal, + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B ack dec"); + + let bob_record = bob_store + .session_store + .load_session(&alice_address) + .now_or_never() + .expect("sync") + .expect("load succeeded") + .expect("session exists"); + let serialized = bob_record.serialize().expect("serialize"); + let mut record_pb = RecordStructure::decode(serialized.as_slice()).expect("decode record"); + assert_eq!( + record_pb.previous_sessions.len(), + 1, + "expected one valid previous session" + ); + record_pb.previous_sessions.insert(0, vec![0xFF]); + let corrupted_record = SessionRecord::deserialize(record_pb.encode_to_vec().as_slice()) + .expect("deserialize mutated record"); + bob_store + .session_store + .store_session(&alice_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let ptext = message_decrypt_signal( + &delayed_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("delayed msg dec via valid later previous session"); + + assert_eq!(ptext, delayed_plaintext); + } + + fn setup_two_alice_receiver_chains_on_bob( + rng: &mut ChaCha8Rng, + ) -> ( + InMemSignalProtocolStore, + InMemSignalProtocolStore, + ProtocolAddress, + ProtocolAddress, + SignalMessage, + ) { + let (mut alice_store, mut bob_store, alice_address, bob_address) = setup_stores(rng); + let now = SystemTime::now(); + + let delayed_ct = message_encrypt( + b"delayed old", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("delayed old enc"); + let delayed_signal_msg = match delayed_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected delayed SignalMessage, got {:?}", + other.message_type() + ), + }; + + let trigger_ct = message_encrypt( + b"trigger old chain advancement", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("trigger enc"); + let trigger_signal_msg = match trigger_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected trigger SignalMessage, got {:?}", + other.message_type() + ), + }; + message_decrypt_signal( + &trigger_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect("trigger dec"); + + let bob_reply_ct = message_encrypt( + b"bob reply new ratchet", + &alice_address, + &bob_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("bob reply enc"); + let bob_reply_signal_msg = match bob_reply_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected bob reply SignalMessage, got {:?}", + other.message_type() + ), + }; + message_decrypt_signal( + &bob_reply_signal_msg, + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect("bob reply dec"); + + let alice_new_chain_ct = message_encrypt( + b"alice new chain", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("alice new chain enc"); + let alice_new_chain_signal_msg = match alice_new_chain_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected alice new chain SignalMessage, got {:?}", + other.message_type() + ), + }; + message_decrypt_signal( + &alice_new_chain_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect("alice new chain dec"); + + ( + alice_store, + bob_store, + alice_address, + bob_address, + delayed_signal_msg, + ) + } + + #[test] + fn decrypt_ignores_corrupt_unmatched_receiver_chain() { + let mut rng = ChaCha8Rng::seed_from_u64(0xD311A9); + let (_alice_store, mut bob_store, alice_address, _bob_address, delayed_signal_msg) = + setup_two_alice_receiver_chains_on_bob(&mut rng); + + let bob_record = bob_store + .session_store + .load_session(&alice_address) + .now_or_never() + .expect("sync") + .expect("load succeeded") + .expect("session exists"); + let serialized = bob_record.serialize().expect("serialize"); + let mut record_pb = RecordStructure::decode(serialized.as_slice()).expect("decode record"); + let current_session = record_pb.current_session.as_mut().expect("current session"); + assert!( + current_session.receiver_chains.len() >= 2, + "expected at least two receiver chains" + ); + + let matched_key = delayed_signal_msg.sender_ratchet_key().serialize().to_vec(); + let matched_idx = current_session + .receiver_chains + .iter() + .position(|chain| chain.sender_ratchet_key == matched_key) + .expect("matching receiver chain present"); + let unmatched_idx = (0..current_session.receiver_chains.len()) + .find(|idx| *idx != matched_idx) + .expect("unmatched receiver chain present"); + + current_session.receiver_chains[unmatched_idx].sender_ratchet_key = vec![0xFF]; + + let corrupted_record = SessionRecord::deserialize(record_pb.encode_to_vec().as_slice()) + .expect("deserialize mutated record"); + bob_store + .session_store + .store_session(&alice_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let ptext = message_decrypt_signal( + &delayed_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("decrypt should ignore unmatched corrupt receiver chain"); + + assert_eq!(ptext, b"delayed old"); + } + + #[test] + fn decrypt_fails_on_corrupt_matched_receiver_chain() { + let mut rng = ChaCha8Rng::seed_from_u64(0xD311AA); + let (_alice_store, mut bob_store, alice_address, _bob_address, delayed_signal_msg) = + setup_two_alice_receiver_chains_on_bob(&mut rng); + + let bob_record = bob_store + .session_store + .load_session(&alice_address) + .now_or_never() + .expect("sync") + .expect("load succeeded") + .expect("session exists"); + let serialized = bob_record.serialize().expect("serialize"); + let mut record_pb = RecordStructure::decode(serialized.as_slice()).expect("decode record"); + let current_session = record_pb.current_session.as_mut().expect("current session"); + + let matched_key = delayed_signal_msg.sender_ratchet_key().serialize().to_vec(); + let matched_idx = current_session + .receiver_chains + .iter() + .position(|chain| chain.sender_ratchet_key == matched_key) + .expect("matching receiver chain present"); + current_session.receiver_chains[matched_idx] + .chain_key + .as_mut() + .expect("chain key present") + .key = vec![0xFF]; + + let corrupted_record = SessionRecord::deserialize(record_pb.encode_to_vec().as_slice()) + .expect("deserialize mutated record"); + bob_store + .session_store + .store_session(&alice_address, &corrupted_record) + .now_or_never() + .expect("sync") + .expect("store succeeded"); + + let err = message_decrypt_signal( + &delayed_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect_err("decrypt should fail on corrupt matched receiver chain"); + + assert!( + matches!( + err, + SignalProtocolError::InvalidMessage( + CiphertextMessageType::Whisper, + "decryption failed" + ) + ), + "unexpected error: {err:?}" + ); + } + + // ── Dual-path simulation helpers ──────────────────────────────────── + // + // Run every operation on both the refactored and legacy code paths, + // asserting identical outputs (ciphertexts, plaintexts, or error + // variants). RNG sync follows the same clone-before-each-op pattern + // as proptest_ciphertext_equality. + + fn assert_store_state_equivalent( + new_store: &InMemSignalProtocolStore, + leg_store: &InMemSignalProtocolStore, + peer_addr: &ProtocolAddress, + context: &str, + ) { + let new_session = new_store + .session_store + .load_session(peer_addr) + .now_or_never() + .expect("sync") + .expect("new load session"); + let leg_session = leg_store + .session_store + .load_session(peer_addr) + .now_or_never() + .expect("sync") + .expect("legacy load session"); + + let new_session_bytes = new_session.map(|record| record.serialize().expect("serialize")); + let leg_session_bytes = leg_session.map(|record| record.serialize().expect("serialize")); + assert_eq!( + new_session_bytes, leg_session_bytes, + "{context}: session records diverged" + ); + + let new_identity = new_store + .identity_store + .get_identity(peer_addr) + .now_or_never() + .expect("sync") + .expect("new load identity") + .map(|identity| identity.serialize()); + let leg_identity = leg_store + .identity_store + .get_identity(peer_addr) + .now_or_never() + .expect("sync") + .expect("legacy load identity") + .map(|identity| identity.serialize()); + assert_eq!( + new_identity, leg_identity, + "{context}: trusted identities diverged" + ); + } + + /// Encrypt on both paths with cloned RNG. Assert ciphertexts and sender + /// state are byte-identical. + fn dual_encrypt( + plaintext: &[u8], + recv_addr: &ProtocolAddress, + send_addr: &ProtocolAddress, + new_sender: &mut InMemSignalProtocolStore, + leg_sender: &mut InMemSignalProtocolStore, + now: SystemTime, + rng: &mut ChaCha8Rng, + ) -> SignalMessage { + let mut leg_rng = rng.clone(); + + let new_ct = message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut new_sender.session_store, + &mut new_sender.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("new encrypt"); + + let leg_ct = legacy::legacy_message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut leg_sender.session_store, + &mut leg_sender.identity_store, + now, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy encrypt"); + + assert_eq!( + new_ct.serialize(), + leg_ct.serialize(), + "new and legacy produced different ciphertexts" + ); + assert_eq!( + new_ct.message_type(), + leg_ct.message_type(), + "new and legacy produced different ciphertext types" + ); + assert_store_state_equivalent(new_sender, leg_sender, recv_addr, "encrypt"); + + match new_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage from dual_encrypt, got {:?}", + other.message_type() + ), + } + } + + /// Encrypt on both paths with cloned RNG. Assert full ciphertext + /// equivalence, including `PreKeySignalMessage`. + fn dual_encrypt_any( + plaintext: &[u8], + recv_addr: &ProtocolAddress, + send_addr: &ProtocolAddress, + new_sender: &mut InMemSignalProtocolStore, + leg_sender: &mut InMemSignalProtocolStore, + now: SystemTime, + rng: &mut ChaCha8Rng, + ) -> CiphertextMessage { + let mut leg_rng = rng.clone(); + + let new_ct = message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut new_sender.session_store, + &mut new_sender.identity_store, + now, + rng, + ) + .now_or_never() + .expect("sync") + .expect("new encrypt"); + + let leg_ct = legacy::legacy_message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut leg_sender.session_store, + &mut leg_sender.identity_store, + now, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy encrypt"); + + assert_eq!( + new_ct.serialize(), + leg_ct.serialize(), + "new and legacy produced different ciphertexts" + ); + assert_eq!( + new_ct.message_type(), + leg_ct.message_type(), + "new and legacy produced different ciphertext types" + ); + assert_store_state_equivalent(new_sender, leg_sender, recv_addr, "encrypt"); + new_ct + } + + /// Decrypt on both paths with cloned RNG. Assert plaintexts and receiver + /// state match. + fn dual_decrypt( + msg: &SignalMessage, + sender_addr: &ProtocolAddress, + new_receiver: &mut InMemSignalProtocolStore, + leg_receiver: &mut InMemSignalProtocolStore, + rng: &mut ChaCha8Rng, + ) -> Vec { + let mut leg_rng = rng.clone(); + + let new_pt = message_decrypt_signal( + msg, + sender_addr, + &mut new_receiver.session_store, + &mut new_receiver.identity_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect("new decrypt"); + + let leg_pt = legacy::legacy_message_decrypt_signal( + msg, + sender_addr, + &mut leg_receiver.session_store, + &mut leg_receiver.identity_store, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy decrypt"); + + assert_eq!( + new_pt, leg_pt, + "new and legacy produced different plaintexts" + ); + assert_store_state_equivalent(new_receiver, leg_receiver, sender_addr, "decrypt"); + new_pt + } + + /// Decrypt any ciphertext on both paths with cloned RNG. Assert + /// plaintexts and receiver-side state match. + fn dual_decrypt_any( + msg: &CiphertextMessage, + sender_addr: &ProtocolAddress, + receiver_addr: &ProtocolAddress, + new_receiver: &mut InMemSignalProtocolStore, + leg_receiver: &mut InMemSignalProtocolStore, + rng: &mut ChaCha8Rng, + ) -> Vec { + let mut leg_rng = rng.clone(); + + let new_pt = message_decrypt( + msg, + sender_addr, + receiver_addr, + &mut new_receiver.session_store, + &mut new_receiver.identity_store, + &mut new_receiver.pre_key_store, + &new_receiver.signed_pre_key_store, + &mut new_receiver.kyber_pre_key_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect("new decrypt"); + + let leg_pt = legacy::legacy_message_decrypt( + msg, + sender_addr, + receiver_addr, + &mut leg_receiver.session_store, + &mut leg_receiver.identity_store, + &mut leg_receiver.pre_key_store, + &leg_receiver.signed_pre_key_store, + &mut leg_receiver.kyber_pre_key_store, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy decrypt"); + + assert_eq!( + new_pt, leg_pt, + "new and legacy produced different plaintexts" + ); + assert_store_state_equivalent(new_receiver, leg_receiver, sender_addr, "decrypt"); + new_pt + } + + fn dual_decrypt_any_result( + msg: &CiphertextMessage, + sender_addr: &ProtocolAddress, + receiver_addr: &ProtocolAddress, + new_receiver: &mut InMemSignalProtocolStore, + leg_receiver: &mut InMemSignalProtocolStore, + rng: &mut ChaCha8Rng, + ) -> Result> { + let mut leg_rng = rng.clone(); + + let new_result = message_decrypt( + msg, + sender_addr, + receiver_addr, + &mut new_receiver.session_store, + &mut new_receiver.identity_store, + &mut new_receiver.pre_key_store, + &new_receiver.signed_pre_key_store, + &mut new_receiver.kyber_pre_key_store, + rng, + ) + .now_or_never() + .expect("sync"); + + let leg_result = legacy::legacy_message_decrypt( + msg, + sender_addr, + receiver_addr, + &mut leg_receiver.session_store, + &mut leg_receiver.identity_store, + &mut leg_receiver.pre_key_store, + &leg_receiver.signed_pre_key_store, + &mut leg_receiver.kyber_pre_key_store, + &mut leg_rng, + ) + .now_or_never() + .expect("sync"); + + match (new_result, leg_result) { + (Ok(new_pt), Ok(leg_pt)) => { + assert_eq!( + new_pt, leg_pt, + "new and legacy produced different plaintexts" + ); + assert_store_state_equivalent(new_receiver, leg_receiver, sender_addr, "decrypt"); + Ok(new_pt) + } + (Err(new_err), Err(leg_err)) => { + assert_eq!( + std::mem::discriminant(&new_err), + std::mem::discriminant(&leg_err), + "error variants differ: new={new_err:?}, legacy={leg_err:?}" + ); + assert_store_state_equivalent(new_receiver, leg_receiver, sender_addr, "decrypt"); + Err(new_err) + } + (new_result, leg_result) => panic!( + "new and legacy disagreed on decrypt result: new={new_result:?}, legacy={leg_result:?}" + ), + } + } + + /// Decrypt on both paths, assert both fail with the same error variant. + fn dual_decrypt_expect_err( + msg: &SignalMessage, + sender_addr: &ProtocolAddress, + new_receiver: &mut InMemSignalProtocolStore, + leg_receiver: &mut InMemSignalProtocolStore, + rng: &mut ChaCha8Rng, + ) -> SignalProtocolError { + let mut leg_rng = rng.clone(); + + let new_err = message_decrypt_signal( + msg, + sender_addr, + &mut new_receiver.session_store, + &mut new_receiver.identity_store, + rng, + ) + .now_or_never() + .expect("sync") + .expect_err("expected new decrypt to fail"); + + let leg_err = legacy::legacy_message_decrypt_signal( + msg, + sender_addr, + &mut leg_receiver.session_store, + &mut leg_receiver.identity_store, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect_err("expected legacy decrypt to fail"); + + assert_eq!( + std::mem::discriminant(&new_err), + std::mem::discriminant(&leg_err), + "error variants differ: new={new_err:?}, legacy={leg_err:?}" + ); + new_err + } + + // ── DualSession convenience wrapper ───────────────────────────────── + + /// Paired new+legacy session state for readable scenario tests. + struct DualSession { + na: InMemSignalProtocolStore, + nb: InMemSignalProtocolStore, + la: InMemSignalProtocolStore, + lb: InMemSignalProtocolStore, + alice: ProtocolAddress, + bob: ProtocolAddress, + rng: ChaCha8Rng, + now: SystemTime, + } + + impl DualSession { + fn new(seed: u64) -> Self { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let (na, nb, alice, bob) = setup_stores(&mut rng); + let (la, lb) = (na.clone(), nb.clone()); + Self { + na, + nb, + la, + lb, + alice, + bob, + rng, + now: SystemTime::now(), + } + } + + fn alice_sends(&mut self, plaintext: &[u8]) -> SignalMessage { + dual_encrypt( + plaintext, + &self.bob, + &self.alice, + &mut self.na, + &mut self.la, + self.now, + &mut self.rng, + ) + } + + fn bob_sends(&mut self, plaintext: &[u8]) -> SignalMessage { + dual_encrypt( + plaintext, + &self.alice, + &self.bob, + &mut self.nb, + &mut self.lb, + self.now, + &mut self.rng, + ) + } + + fn bob_receives(&mut self, msg: &SignalMessage) -> Vec { + dual_decrypt(msg, &self.alice, &mut self.nb, &mut self.lb, &mut self.rng) + } + + fn alice_receives(&mut self, msg: &SignalMessage) -> Vec { + dual_decrypt(msg, &self.bob, &mut self.na, &mut self.la, &mut self.rng) + } + + fn bob_receives_err(&mut self, msg: &SignalMessage) -> SignalProtocolError { + dual_decrypt_expect_err(msg, &self.alice, &mut self.nb, &mut self.lb, &mut self.rng) + } + + #[allow(dead_code)] + fn alice_receives_err(&mut self, msg: &SignalMessage) -> SignalProtocolError { + dual_decrypt_expect_err(msg, &self.bob, &mut self.na, &mut self.la, &mut self.rng) + } + } + + impl DualParticipant { + fn new( + _name: &'static str, + address: ProtocolAddress, + rng: &mut (impl rand::Rng + rand::CryptoRng), + ) -> Self { + let identity = IdentityKeyPair::generate(rng); + let store = InMemSignalProtocolStore::new(identity, rng.random()).unwrap(); + Self { + address, + message_queue: Vec::new(), + state: DualLocalState { + new_store: store.clone(), + legacy_store: store, + pre_key_count: 0, + }, + snapshots: Vec::new(), + message_send_log: Vec::new(), + } + } + + fn address(&self) -> &ProtocolAddress { + &self.address + } + + fn has_pending_incoming_messages(&self) -> bool { + !self.message_queue.is_empty() + } + + fn assert_equivalent_with(&self, them: &Self, context: &str) { + assert_store_state_equivalent( + &self.state.new_store, + &self.state.legacy_store, + &them.address, + context, + ); + } + + async fn process_pre_key( + &mut self, + them: &mut Self, + use_one_time_pre_key: bool, + rng: &mut ChaCha8Rng, + ) { + let their_signed_pre_key_pair = KeyPair::generate(rng); + let their_signed_pre_key_public = their_signed_pre_key_pair.public_key.serialize(); + let identity_key_pair = them.state.new_store.get_identity_key_pair().await.unwrap(); + let their_signed_pre_key_signature = identity_key_pair + .private_key() + .calculate_signature(&their_signed_pre_key_public, rng) + .unwrap(); + + them.state.pre_key_count += 1; + let signed_pre_key_id: SignedPreKeyId = them.state.pre_key_count.into(); + let signed_pre_key_record = SignedPreKeyRecord::new( + signed_pre_key_id, + Timestamp::from_epoch_millis(42), + &their_signed_pre_key_pair, + &their_signed_pre_key_signature, + ); + them.state + .new_store + .save_signed_pre_key(signed_pre_key_id, &signed_pre_key_record) + .await + .unwrap(); + them.state + .legacy_store + .save_signed_pre_key(signed_pre_key_id, &signed_pre_key_record) + .await + .unwrap(); + + them.state.pre_key_count += 1; + let pre_key_id: PreKeyId = them.state.pre_key_count.into(); + let pre_key_info = if use_one_time_pre_key { + let one_time_pre_key = KeyPair::generate(rng); + let pre_key_record = PreKeyRecord::new(pre_key_id, &one_time_pre_key); + them.state + .new_store + .save_pre_key(pre_key_id, &pre_key_record) + .await + .unwrap(); + them.state + .legacy_store + .save_pre_key(pre_key_id, &pre_key_record) + .await + .unwrap(); + Some((pre_key_id, one_time_pre_key.public_key)) + } else { + None + }; + + let their_kyber_pre_key_pair = + crate::kem::KeyPair::generate(crate::kem::KeyType::Kyber1024, rng); + let their_kyber_pre_key_public = their_kyber_pre_key_pair.public_key.serialize(); + let their_kyber_pre_key_signature = identity_key_pair + .private_key() + .calculate_signature(&their_kyber_pre_key_public, rng) + .unwrap(); + + them.state.pre_key_count += 1; + let kyber_pre_key_id: KyberPreKeyId = them.state.pre_key_count.into(); + let kyber_pre_key_record = KyberPreKeyRecord::new( + kyber_pre_key_id, + Timestamp::from_epoch_millis(42), + &their_kyber_pre_key_pair, + &their_kyber_pre_key_signature, + ); + them.state + .new_store + .save_kyber_pre_key(kyber_pre_key_id, &kyber_pre_key_record) + .await + .unwrap(); + them.state + .legacy_store + .save_kyber_pre_key(kyber_pre_key_id, &kyber_pre_key_record) + .await + .unwrap(); + + let their_pre_key_bundle = PreKeyBundle::new( + them.state + .new_store + .get_local_registration_id() + .await + .unwrap(), + DeviceId::new(1).unwrap(), + pre_key_info, + signed_pre_key_id, + their_signed_pre_key_pair.public_key, + their_signed_pre_key_signature.into_vec(), + kyber_pre_key_id, + their_kyber_pre_key_pair.public_key, + their_kyber_pre_key_signature.into_vec(), + *identity_key_pair.identity_key(), + ) + .unwrap(); + + let mut legacy_rng = rng.clone(); + process_prekey_bundle( + &them.address, + &mut self.state.new_store.session_store, + &mut self.state.new_store.identity_store, + &their_pre_key_bundle, + SystemTime::UNIX_EPOCH, + rng, + ) + .await + .unwrap(); + process_prekey_bundle( + &them.address, + &mut self.state.legacy_store.session_store, + &mut self.state.legacy_store.identity_store, + &their_pre_key_bundle, + SystemTime::UNIX_EPOCH, + &mut legacy_rng, + ) + .await + .unwrap(); + + self.assert_equivalent_with(them, "process_pre_key/self"); + them.assert_equivalent_with(self, "process_pre_key/them"); + assert!( + self.state + .new_store + .load_session(&them.address) + .await + .unwrap() + .expect("just created") + .has_usable_sender_chain( + SystemTime::UNIX_EPOCH, + SessionUsabilityRequirements::all(), + ) + .unwrap() + ); + } + + async fn send_message(&mut self, them: &mut Self, rng: &mut ChaCha8Rng) { + self.send_message_with_id(them, self.message_send_log.len().try_into().unwrap(), rng) + .await; + self.message_send_log.push(MessageStatus::Sent); + } + + async fn send_message_with_id(&mut self, them: &mut Self, id: u64, rng: &mut ChaCha8Rng) { + let has_usable_sender_chain = self + .state + .new_store + .load_session(&them.address) + .await + .unwrap() + .and_then(|session| { + session + .has_usable_sender_chain( + SystemTime::UNIX_EPOCH, + SessionUsabilityRequirements::all(), + ) + .ok() + }) + .unwrap_or(false); + + if !has_usable_sender_chain { + self.process_pre_key(them, rng.random_bool(0.75), rng).await; + } + + let buffer = id.to_le_bytes(); + let outgoing_message = dual_encrypt_any( + &buffer, + &them.address, + &self.address, + &mut self.state.new_store, + &mut self.state.legacy_store, + SystemTime::UNIX_EPOCH, + rng, + ); + + let incoming_message = match outgoing_message.message_type() { + CiphertextMessageType::PreKey => CiphertextMessage::PreKeySignalMessage( + PreKeySignalMessage::try_from(outgoing_message.serialize()).unwrap(), + ), + CiphertextMessageType::Whisper => CiphertextMessage::SignalMessage( + SignalMessage::try_from(outgoing_message.serialize()).unwrap(), + ), + other_type => panic!("unexpected type {other_type:?}"), + }; + + them.message_queue.push((incoming_message, id)); + self.assert_equivalent_with(them, "send"); + } + + async fn receive_messages(&mut self, them: &mut Self, rng: &mut ChaCha8Rng) { + for (incoming_message, expected) in self.message_queue.split_off(0) { + match incoming_message { + CiphertextMessage::SignalMessage(_) + | CiphertextMessage::PreKeySignalMessage(_) => { + match dual_decrypt_any_result( + &incoming_message, + &them.address, + &self.address, + &mut self.state.new_store, + &mut self.state.legacy_store, + rng, + ) { + Ok(decrypted) => { + assert_eq!(expected.to_le_bytes(), &decrypted[..]); + them.ack(expected); + } + Err(_) => { + let error_msg = DecryptionErrorMessage::for_original( + incoming_message.serialize(), + incoming_message.message_type(), + Timestamp::from_epoch_millis(expected), + 1, + ) + .expect("can encode DEM"); + them.message_queue.push(( + CiphertextMessage::PlaintextContent(error_msg.into()), + u64::MAX, + )); + } + } + } + CiphertextMessage::SenderKeyMessage(_) => unreachable!(), + CiphertextMessage::PlaintextContent(content) => { + self.handle_decryption_error(them, content, rng).await; + } + } + } + self.assert_equivalent_with(them, "receive"); + them.assert_equivalent_with(self, "receive/peer"); + } + + fn drop_message(&mut self, them: &mut Self) { + match self.message_queue.pop() { + None | Some((CiphertextMessage::PlaintextContent(_), _)) => {} + Some((_, id)) => them.nack(id), + } + } + + fn shuffle_messages(&mut self, rng: &mut impl rand::Rng) { + use rand::seq::SliceRandom as _; + self.message_queue.shuffle(rng); + } + + async fn handle_decryption_error( + &mut self, + them: &mut Self, + content: PlaintextContent, + rng: &mut ChaCha8Rng, + ) { + let dem = extract_decryption_error_message_from_serialized_content(content.body()) + .expect("all PlaintextContent is DEM"); + assert_eq!(dem.device_id(), 1); + + let id = dem.timestamp().epoch_millis(); + let Some(status) = self.message_send_log.get(usize::try_from(id).unwrap()) else { + panic!( + "failed to decrypt an unsent message {id} ({} total sent)", + self.message_send_log.len() + ) + }; + match status { + MessageStatus::Sent => {} + MessageStatus::Dropped => { + panic!("got a decryption error for dropped message {id}"); + } + MessageStatus::Delivered => { + panic!("got a decryption error for successfully delivered message {id}"); + } + } + + let ratchet_key = dem + .ratchet_key() + .expect("all DEMs for 1:1 messages have ratchet keys"); + if self + .state + .new_store + .load_session(&them.address) + .await + .unwrap() + .is_some_and(|session| { + session + .current_ratchet_key_matches(ratchet_key) + .expect("structurally valid session") + }) + { + self.archive_session(&them.address).await; + } + + self.send_message_with_id(them, id, rng).await; + } + + async fn archive_session(&mut self, their_address: &ProtocolAddress) { + for store in [&mut self.state.new_store, &mut self.state.legacy_store] { + if let Some(mut session) = store.load_session(their_address).await.unwrap() { + session.archive_current_state().unwrap(); + store.store_session(their_address, &session).await.unwrap(); + } + } + } + + fn snapshot_state(&mut self) { + self.snapshots.push(self.state.clone()); + } + + fn restore_from_snapshot_if_exists(&mut self, i: u8) { + let i = usize::from(i); + if i < self.snapshots.len() { + self.state = self.snapshots.remove(i); + } + } + + fn ack(&mut self, id: u64) { + self.update_status(id, MessageStatus::Delivered); + } + + fn nack(&mut self, id: u64) { + self.update_status(id, MessageStatus::Dropped); + } + + fn update_status(&mut self, id: u64, updated_status: MessageStatus) { + let Some(status) = self.message_send_log.get_mut(usize::try_from(id).unwrap()) else { + panic!( + "tried to update unsent message {id} ({} total sent)", + self.message_send_log.len() + ) + }; + match status { + MessageStatus::Sent => *status = updated_status, + MessageStatus::Dropped => panic!("updated dropped message {id}"), + MessageStatus::Delivered => panic!("updated delivered message {id}"), + } + } + + async fn run_event(&mut self, them: &mut Self, event: Event, rng: &mut ChaCha8Rng) { + match event { + Event::Archive => self.archive_session(them.address()).await, + Event::Snapshot => self.snapshot_state(), + Event::Restore { index } => self.restore_from_snapshot_if_exists(index), + Event::Receive => self.receive_messages(them, rng).await, + Event::Drop => self.drop_message(them), + Event::Shuffle => self.shuffle_messages(rng), + Event::Send { count_times_eight } => { + for _ in 0..(count_times_eight / 8) { + self.send_message(them, rng).await; + } + } + } + self.assert_equivalent_with(them, "event"); + them.assert_equivalent_with(self, "event/peer"); + } + } + + // ── Scenario tests ────────────────────────────────────────────────── + + /// Ordinary skipped-key handling remains interoperable when messages are + /// delivered with gaps, later arrive out of order, and both directions + /// continue sending before the session fully catches up. + #[test] + fn scenario_interleaved_delivery_with_gaps_and_recovery() { + let mut s = DualSession::new(0xBEEF_0001); + + // Alice sends a burst to Bob. Bob receives only the first and third, + // leaving a gap that must be recovered later from stored skipped keys. + let a_msgs: Vec<_> = (0u8..4) + .map(|i| (s.alice_sends(&[b'A', i]), vec![b'A', i])) + .collect(); + assert_eq!(s.bob_receives(&a_msgs[0].0), a_msgs[0].1, "alice msg 0"); + assert_eq!(s.bob_receives(&a_msgs[2].0), a_msgs[2].1, "alice msg 2"); + + // Before Alice's burst is fully drained, Bob sends his own burst. + // Alice receives only the later message first, exercising the same + // skipped-key path in the opposite direction. + let b_msgs: Vec<_> = (0u8..3) + .map(|i| (s.bob_sends(&[b'B', i]), vec![b'B', i])) + .collect(); + assert_eq!(s.alice_receives(&b_msgs[2].0), b_msgs[2].1, "bob msg 2"); + + // The missing earlier messages now arrive and must still decrypt. + assert_eq!(s.bob_receives(&a_msgs[1].0), a_msgs[1].1, "alice msg 1"); + assert_eq!(s.bob_receives(&a_msgs[3].0), a_msgs[3].1, "alice msg 3"); + assert_eq!(s.alice_receives(&b_msgs[0].0), b_msgs[0].1, "bob msg 0"); + assert_eq!(s.alice_receives(&b_msgs[1].0), b_msgs[1].1, "bob msg 1"); + + // After recovering the gaps, both directions should continue in steady + // state without any special handling. + let alice_followup = s.alice_sends(b"alice steady"); + assert_eq!(s.bob_receives(&alice_followup), b"alice steady"); + let bob_followup = s.bob_sends(b"bob steady"); + assert_eq!(s.alice_receives(&bob_followup), b"bob steady"); + } + + /// Skip past MAX_FORWARD_JUMPS — both paths must reject with the same + /// error. Encrypts on just the new path for performance (25k+ messages); + /// both receivers start from identical untouched state. + #[test] + fn scenario_chain_jump_over_limit() { + let mut rng = ChaCha8Rng::seed_from_u64(0xBEEF_0005); + let (mut na, mut nb, alice, bob) = setup_stores(&mut rng); + let mut lb = nb.clone(); + let now = SystemTime::now(); + + let count = crate::consts::MAX_FORWARD_JUMPS + 2; + let mut last = None; + for _ in 0..count { + let ct = message_encrypt( + b"x", + &bob, + &alice, + &mut na.session_store, + &mut na.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("encrypt"); + last = Some(match ct { + CiphertextMessage::SignalMessage(m) => m, + _ => panic!("not SignalMessage"), + }); + } + let msg = last.unwrap(); + + // Bob has received nothing — new and legacy stores are identical. + let mut leg_rng = rng.clone(); + let new_err = message_decrypt_signal( + &msg, + &alice, + &mut nb.session_store, + &mut nb.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect_err("should exceed jump limit"); + + let leg_err = legacy::legacy_message_decrypt_signal( + &msg, + &alice, + &mut lb.session_store, + &mut lb.identity_store, + &mut leg_rng, + ) + .now_or_never() + .expect("sync") + .expect_err("should exceed jump limit"); + + assert_eq!( + std::mem::discriminant(&new_err), + std::mem::discriminant(&leg_err), + "error variants differ: new={new_err:?}, legacy={leg_err:?}" + ); + assert!( + matches!(new_err, SignalProtocolError::InvalidMessage(..)), + "expected InvalidMessage, got {new_err:?}" + ); + } + + /// The initial unacknowledged send must be bit-identical as a + /// `PreKeySignalMessage`, and both sides must end up with identical + /// session state after the ack round-trip. + #[test] + fn scenario_prekey_session_establishment_equivalence() { + let mut rng = ChaCha8Rng::seed_from_u64(0xBEEF_0008); + let alice_identity = IdentityKeyPair::generate(&mut rng); + let bob_identity = IdentityKeyPair::generate(&mut rng); + let alice = ProtocolAddress::new( + "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + DeviceId::new(1).unwrap(), + ); + let bob = ProtocolAddress::new( + "796abedb-ca4e-4f18-8803-1fde5b921f9f".to_owned(), + DeviceId::new(1).unwrap(), + ); + let now = SystemTime::now(); + + let alice_base = InMemSignalProtocolStore::new(alice_identity, 1).expect("alice store"); + let mut bob_base = InMemSignalProtocolStore::new(bob_identity, 2).expect("bob store"); + let bundle = create_bob_bundle(&mut bob_base, 1, 1, 1, &mut rng); + + let (mut alice_new, mut alice_legacy) = (alice_base.clone(), alice_base.clone()); + let (mut bob_new, mut bob_legacy) = (bob_base.clone(), bob_base.clone()); + + let mut legacy_rng = rng.clone(); + process_prekey_bundle( + &bob, + &mut alice_new.session_store, + &mut alice_new.identity_store, + &bundle, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new process_prekey_bundle"); + process_prekey_bundle( + &bob, + &mut alice_legacy.session_store, + &mut alice_legacy.identity_store, + &bundle, + now, + &mut legacy_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy process_prekey_bundle"); + assert_store_state_equivalent(&alice_new, &alice_legacy, &bob, "post-bundle"); + + let init = dual_encrypt_any( + b"session init", + &bob, + &alice, + &mut alice_new, + &mut alice_legacy, + now, + &mut rng, + ); + assert!( + matches!(init, CiphertextMessage::PreKeySignalMessage(_)), + "expected first message after bundle processing to be PreKey" + ); + + assert_eq!( + dual_decrypt_any(&init, &alice, &bob, &mut bob_new, &mut bob_legacy, &mut rng), + b"session init" + ); + + let ack = dual_encrypt_any( + b"session ack", + &alice, + &bob, + &mut bob_new, + &mut bob_legacy, + now, + &mut rng, + ); + assert!( + matches!(ack, CiphertextMessage::SignalMessage(_)), + "expected ack to be a SignalMessage" + ); + assert_eq!( + dual_decrypt_any( + &ack, + &bob, + &alice, + &mut alice_new, + &mut alice_legacy, + &mut rng + ), + b"session ack" + ); + + let followup = dual_encrypt_any( + b"steady state", + &bob, + &alice, + &mut alice_new, + &mut alice_legacy, + now, + &mut rng, + ); + assert!( + matches!(followup, CiphertextMessage::SignalMessage(_)), + "expected acknowledged session to emit SignalMessage" + ); + assert_eq!( + dual_decrypt_any( + &followup, + &alice, + &bob, + &mut bob_new, + &mut bob_legacy, + &mut rng + ), + b"steady state" + ); + } + + /// Flip a byte in the MAC — both paths must reject identically, and + /// the original message must still decrypt afterward. + #[test] + fn scenario_corrupted_ciphertext() { + let mut s = DualSession::new(0xBEEF_0006); + + // Warm up the session with a round-trip + let msg = s.alice_sends(b"hello"); + assert_eq!(s.bob_receives(&msg), b"hello"); + let msg = s.bob_sends(b"hi"); + assert_eq!(s.alice_receives(&msg), b"hi"); + + // Alice sends a message; corrupt the last byte (in the MAC) + let msg = s.alice_sends(b"secret"); + let mut corrupted_bytes = msg.serialized().to_vec(); + let len = corrupted_bytes.len(); + corrupted_bytes[len - 1] ^= 0xFF; + let corrupted = + SignalMessage::try_from(corrupted_bytes.as_slice()).expect("parse corrupted message"); + + let err = s.bob_receives_err(&corrupted); + assert!( + matches!( + err, + SignalProtocolError::InvalidMessage(CiphertextMessageType::Whisper, _) + ), + "expected InvalidMessage(Whisper, _), got {err:?}" + ); + + // The original (uncorrupted) message still decrypts — failed MAC + // check does not persist state changes. + assert_eq!(s.bob_receives(&msg), b"secret"); + } + + /// Replay an already-decrypted message — both paths must detect the + /// duplicate. + #[test] + fn scenario_replay_message() { + let mut s = DualSession::new(0xBEEF_0007); + + let msg = s.alice_sends(b"once"); + assert_eq!(s.bob_receives(&msg), b"once"); + + // Same ciphertext again — should be detected as duplicate + let err = s.bob_receives_err(&msg); + assert!( + matches!(err, SignalProtocolError::DuplicatedMessage(..)), + "expected DuplicatedMessage, got {err:?}" + ); + } + + proptest! { + /// Reuse the existing session-reset event model from `test-support`, + /// but execute every encrypt/decrypt on both new and legacy codepaths. + #[test] + fn proptest_event_model_matches_legacy( + actions in prop::collection::vec( + (prop::bool::ANY, proptest_arbitrary_interop::arb::()), + 0..40, + ), + ) { + let mut rng = ChaCha8Rng::seed_from_u64(0); + let mut alice = DualParticipant::new( + "alice", + ProtocolAddress::new("9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), DeviceId::new(1).unwrap()), + &mut rng, + ); + let mut bob = DualParticipant::new( + "bob", + ProtocolAddress::new("796abedb-ca4e-4f18-8803-1fde5b921f9f".to_owned(), DeviceId::new(1).unwrap()), + &mut rng, + ); + + for (who, event) in actions { + let (me, them) = if who { + (&mut alice, &mut bob) + } else { + (&mut bob, &mut alice) + }; + me.run_event(them, event, &mut rng) + .now_or_never() + .expect("sync"); + } + + while alice.has_pending_incoming_messages() || bob.has_pending_incoming_messages() { + alice + .receive_messages(&mut bob, &mut rng) + .now_or_never() + .expect("sync"); + bob.receive_messages(&mut alice, &mut rng) + .now_or_never() + .expect("sync"); + } + + for _ in 0..8 { + alice + .send_message(&mut bob, &mut rng) + .now_or_never() + .expect("sync"); + bob.receive_messages(&mut alice, &mut rng) + .now_or_never() + .expect("sync"); + bob.send_message(&mut alice, &mut rng) + .now_or_never() + .expect("sync"); + alice + .receive_messages(&mut bob, &mut rng) + .now_or_never() + .expect("sync"); + } + + alice.assert_equivalent_with(&bob, "final/alice"); + bob.assert_equivalent_with(&alice, "final/bob"); + } + } + + proptest! { + /// New code can take over a session whose state was last written by + /// legacy code. + /// + /// Runs `legacy_actions` using legacy enc+dec on both sides, then + /// switches both sides to new enc+dec for `new_actions`. The session + /// state — chain keys, ratchet state, SPQR state — was written by the + /// legacy decrypt path; new code must read and advance it correctly. + #[test] + fn proptest_legacy_handover_to_new( + seed in 0u64..u64::MAX, + legacy_actions in prop::collection::vec( + (prop::bool::ANY, prop::collection::vec(any::(), 0..=64)), + 1..=10, + ), + new_actions in prop::collection::vec( + (prop::bool::ANY, prop::collection::vec(any::(), 0..=64)), + 1..=10, + ), + ) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let (mut alice_store, mut bob_store, alice_address, bob_address) = + setup_stores(&mut rng); + let now = std::time::SystemTime::now(); + + // Phase 1: legacy enc + legacy dec advance the session state. + for (alice_sends, plaintext) in &legacy_actions { + let (sender, receiver, recv_addr, send_addr) = if *alice_sends { + (&mut alice_store, &mut bob_store, &bob_address, &alice_address) + } else { + (&mut bob_store, &mut alice_store, &alice_address, &bob_address) + }; + + let ct = legacy::legacy_message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut sender.session_store, + &mut sender.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy enc"); + + let signal_msg = match &ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage in legacy phase, got {:?}", + other.message_type() + ), + }; + + let ptext = legacy::legacy_message_decrypt_signal( + signal_msg, + send_addr, + &mut receiver.session_store, + &mut receiver.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy dec"); + + prop_assert_eq!(ptext, plaintext.clone(), "legacy phase: wrong plaintext"); + } + + // Phase 2: new code takes over the session state left by legacy. + for (alice_sends, plaintext) in &new_actions { + let (sender, receiver, recv_addr, send_addr) = if *alice_sends { + (&mut alice_store, &mut bob_store, &bob_address, &alice_address) + } else { + (&mut bob_store, &mut alice_store, &alice_address, &bob_address) + }; + + let ct = message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut sender.session_store, + &mut sender.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new enc"); + + let signal_msg = match &ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage in new phase, got {:?}", + other.message_type() + ), + }; + + let ptext = message_decrypt_signal( + signal_msg, + send_addr, + &mut receiver.session_store, + &mut receiver.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new dec"); + + prop_assert_eq!(ptext, plaintext.clone(), "new phase: wrong plaintext"); + } + } + + /// A message encrypted by legacy code on a previous session is correctly + /// decrypted by new code after a session transition. + /// + /// Scenario: + /// 1. `pre_actions` exchanges on session A using legacy enc+dec. + /// 2. Alice encrypts a delayed Whisper on session A using legacy enc + /// (not yet delivered to Bob). + /// 3. Alice processes a new pre-key bundle from Bob → session B + /// (session A is archived in Alice's previous_sessions). + /// 4. Alice and Bob establish session B on both sides and exchange + /// `post_actions` using new enc+dec. Bob's session A' is archived to + /// his previous_sessions when he receives Alice's first session-B + /// PreKeySignalMessage. + /// 5. The delayed message from step 2 is delivered to Bob via new + /// message_decrypt_signal. try_decrypt_from_record must fail on + /// the current session B' and succeed on the previous session A', + /// exercising promote_old_session. + #[test] + fn proptest_delayed_message_via_previous_session( + seed in 0u64..u64::MAX, + pre_actions in prop::collection::vec( + (prop::bool::ANY, prop::collection::vec(any::(), 0..=6)), + 0..=6, + ), + post_actions in prop::collection::vec( + (prop::bool::ANY, prop::collection::vec(any::(), 0..=64)), + 0..=6, + ), + delayed_plaintext in prop::collection::vec(any::(), 1..=64), + ) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let (mut alice_store, mut bob_store, alice_address, bob_address) = + setup_stores(&mut rng); + let now = std::time::SystemTime::now(); + + // ── Phase 1: legacy enc+dec on session A ───────────────────────── + + for (alice_sends, plaintext) in &pre_actions { + let (sender, receiver, recv_addr, send_addr) = if *alice_sends { + (&mut alice_store, &mut bob_store, &bob_address, &alice_address) + } else { + (&mut bob_store, &mut alice_store, &alice_address, &bob_address) + }; + + let ct = legacy::legacy_message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut sender.session_store, + &mut sender.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("pre legacy enc"); + + let signal_msg = match &ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage in pre phase, got {:?}", + other.message_type() + ), + }; + + let ptext = legacy::legacy_message_decrypt_signal( + signal_msg, + send_addr, + &mut receiver.session_store, + &mut receiver.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("pre legacy dec"); + + prop_assert_eq!(ptext, plaintext.clone(), "pre phase: wrong plaintext"); + } + + // ── Encrypt delayed message (not yet delivered) ────────────────── + + // This Whisper is encrypted by Alice on session A's current chain. + // Bob's session A' is at the same chain index, so it can decrypt + // it later. + let delayed_ct = legacy::legacy_message_encrypt( + &delayed_plaintext, + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("delayed legacy enc"); + + let delayed_signal_msg = match delayed_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage for delayed msg, got {:?}", + other.message_type() + ), + }; + + // ── Session transition: A → B ───────────────────────────────────── + + // Alice processes a new pre-key bundle from Bob. This calls + // promote_state, archiving session A to Alice's previous_sessions. + let bundle = create_bob_bundle(&mut bob_store, 1, 1, 1, &mut rng); + process_prekey_bundle( + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &bundle, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("process_prekey_bundle"); + + // Alice sends her first message on session B (a PreKeySignalMessage + // since B is unacknowledged). When Bob decrypts it, process_prekey + // fires and archives his session A' to previous_sessions. + let session_b_init = message_encrypt( + b"session B init", + &bob_address, + &alice_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B init enc"); + + message_decrypt( + &session_b_init, + &alice_address, + &bob_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut bob_store.pre_key_store, + &bob_store.signed_pre_key_store, + &mut bob_store.kyber_pre_key_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B init dec"); + // Bob now has: current = session B', previous_sessions = [session A'] + + // Bob acknowledges session B on Alice's side. Without this, Alice + // would keep wrapping messages as PreKeySignalMessage, and each + // would trigger another process_prekey on Bob's side, nesting + // sessions further. After this round-trip both sides send Whispers. + let session_b_ack = message_encrypt( + b"session B ack", + &alice_address, + &bob_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B ack enc"); + + let session_b_ack_signal = match &session_b_ack { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected Whisper for session B ack, got {:?}", + other.message_type() + ), + }; + message_decrypt_signal( + session_b_ack_signal, + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("session B ack dec"); + // Alice's session B is now acknowledged; all her sends are Whispers. + + // ── Phase 2: new enc+dec on session B ──────────────────────────── + + for (alice_sends, plaintext) in &post_actions { + let (sender, receiver, recv_addr, send_addr) = if *alice_sends { + (&mut alice_store, &mut bob_store, &bob_address, &alice_address) + } else { + (&mut bob_store, &mut alice_store, &alice_address, &bob_address) + }; + + let ct = message_encrypt( + plaintext, + recv_addr, + send_addr, + &mut sender.session_store, + &mut sender.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("post new enc"); + + let signal_msg = match &ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage in post phase, got {:?}", + other.message_type() + ), + }; + + let ptext = message_decrypt_signal( + signal_msg, + send_addr, + &mut receiver.session_store, + &mut receiver.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("post new dec"); + + prop_assert_eq!(ptext, plaintext.clone(), "post phase: wrong plaintext"); + } + + // ── Deliver delayed message ─────────────────────────────────────── + + // Bob's current session is B'. The delayed message was encrypted + // under session A. try_decrypt_from_record must: + // 1. Try session B' → fail (wrong ratchet key / counter). + // 2. Try session A' from previous_sessions → succeed. + // 3. Call promote_old_session, making A' the current session. + let ptext = message_decrypt_signal( + &delayed_signal_msg, + &alice_address, + &mut bob_store.session_store, + &mut bob_store.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("delayed msg dec via previous session"); + + prop_assert_eq!( + ptext, + delayed_plaintext.clone(), + "delayed message: wrong plaintext" + ); + } + + /// New encrypt and legacy encrypt produce byte-identical ciphertexts + /// when given the same RNG state. + /// + /// Runs two parallel session pairs from the same initial state. + /// Before each encrypt, the RNG is cloned so that both the new and + /// legacy paths start from the same randomness. If the RNG + /// consumption is identical (one `spqr::send` call per encrypt, + /// one `KeyPair::generate` per DH ratchet step on decrypt), the + /// ciphertexts must be equal. The receiver sessions are advanced + /// with the same split so that subsequent iterations stay in sync. + #[test] + fn proptest_ciphertext_equality( + seed in 0u64..u64::MAX, + actions in prop::collection::vec( + (prop::bool::ANY, prop::collection::vec(any::(), 0..=64)), + 1..=20, + ), + ) { + let mut rng = ChaCha8Rng::seed_from_u64(seed); + let (mut alice_new, mut bob_new, alice_address, bob_address) = + setup_stores(&mut rng); + // Clone the freshly-initialized stores so both paths start from + // identical state. + let (mut alice_legacy, mut bob_legacy) = (alice_new.clone(), bob_new.clone()); + let now = SystemTime::now(); + + for (alice_sends, plaintext) in &actions { + // Borrow the right stores for sender/receiver on each path. + let ( + (sender_new, receiver_new), + (sender_legacy, receiver_legacy), + sender_addr, + receiver_addr, + ) = if *alice_sends { + ( + (&mut alice_new, &mut bob_new), + (&mut alice_legacy, &mut bob_legacy), + &alice_address, + &bob_address, + ) + } else { + ( + (&mut bob_new, &mut alice_new), + (&mut bob_legacy, &mut alice_legacy), + &bob_address, + &alice_address, + ) + }; + + // Both encrypt calls start from the same RNG position. + let mut enc_rng = rng.clone(); + + let new_ct = message_encrypt( + plaintext, + receiver_addr, + sender_addr, + &mut sender_new.session_store, + &mut sender_new.identity_store, + now, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new encrypt succeeded"); + + let legacy_ct = legacy::legacy_message_encrypt( + plaintext, + receiver_addr, + sender_addr, + &mut sender_legacy.session_store, + &mut sender_legacy.identity_store, + now, + &mut enc_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy encrypt succeeded"); + + let new_msg = match &new_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage from new enc, got {:?}", + other.message_type() + ), + }; + let legacy_msg = match &legacy_ct { + CiphertextMessage::SignalMessage(m) => m, + other => panic!( + "expected SignalMessage from legacy enc, got {:?}", + other.message_type() + ), + }; + + prop_assert_eq!( + new_msg.serialized(), + legacy_msg.serialized(), + "new and legacy produced different ciphertexts from the same RNG state" + ); + + // Advance both receiver sessions with the same RNG split so + // their states stay in sync for the next iteration. + let mut dec_rng = rng.clone(); + + let _ = message_decrypt_signal( + new_msg, + sender_addr, + &mut receiver_new.session_store, + &mut receiver_new.identity_store, + &mut rng, + ) + .now_or_never() + .expect("sync") + .expect("new decrypt succeeded"); + + let _ = legacy::legacy_message_decrypt_signal( + legacy_msg, + sender_addr, + &mut receiver_legacy.session_store, + &mut receiver_legacy.identity_store, + &mut dec_rng, + ) + .now_or_never() + .expect("sync") + .expect("legacy decrypt succeeded"); + } + } + } +} diff --git a/rust/protocol/src/state/session.rs b/rust/protocol/src/state/session.rs index 92bd67965..0885999f6 100644 --- a/rust/protocol/src/state/session.rs +++ b/rust/protocol/src/state/session.rs @@ -8,18 +8,21 @@ use std::time::{Duration, SystemTime}; use bitflags::bitflags; use prost::Message; +#[cfg(test)] use rand::{CryptoRng, Rng}; use subtle::ConstantTimeEq; use crate::proto::storage::{RecordStructure, SessionStructure, session_structure}; use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION; -use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey}; +#[cfg(test)] +use crate::ratchet::MessageKeyGenerator; +use crate::ratchet::{ChainKey, RootKey}; use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId}; use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem}; /// A distinct error type to keep from accidentally propagating deserialization errors. #[derive(Debug)] -pub(crate) struct InvalidSessionError(&'static str); +pub(crate) struct InvalidSessionError(pub(crate) &'static str); impl std::fmt::Display for InvalidSessionError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -214,10 +217,12 @@ impl SessionState { self.session.previous_counter } + #[cfg(test)] pub(crate) fn set_previous_counter(&mut self, ctr: u32) { self.session.previous_counter = ctr; } + #[cfg(test)] pub(crate) fn root_key(&self) -> Result { let root_key_bytes = self.session.root_key[..] .try_into() @@ -225,6 +230,7 @@ impl SessionState { Ok(RootKey::new(root_key_bytes)) } + #[cfg(test)] pub(crate) fn set_root_key(&mut self, root_key: &RootKey) { self.session.root_key = root_key.key().to_vec(); } @@ -428,6 +434,7 @@ impl SessionState { self.session.sender_chain = Some(new_chain); } + #[cfg(test)] pub(crate) fn get_message_keys( &mut self, sender: &PublicKey, @@ -454,6 +461,7 @@ impl SessionState { Ok(None) } + #[cfg(test)] pub(crate) fn set_message_keys( &mut self, sender: &PublicKey, @@ -474,6 +482,7 @@ impl SessionState { Ok(()) } + #[cfg(test)] pub(crate) fn set_receiver_chain_key( &mut self, sender: &PublicKey, @@ -598,6 +607,39 @@ impl SessionState { .map(|pending| &pending.ciphertext) } + /// Extract the Double Ratchet state from this session. + /// + /// This **moves** the receiver chains out of the session protobuf, + /// leaving them empty. The caller must either apply the ratchet state + /// back via [`apply_ratchet_state`](Self::apply_ratchet_state) or + /// discard this `SessionState` entirely. Reading receiver chains from + /// a taken session is invalid and will produce incorrect results. + /// + /// The caller must supply `self_session` (whether this is a note-to-self + /// session) since it requires identity key comparison, which is above the + /// ratchet layer. + pub(crate) fn take_ratchet_state( + &mut self, + self_session: bool, + ) -> crate::error::Result { + let receiver_chains = std::mem::take(&mut self.session.receiver_chains); + Ok(crate::double_ratchet::RatchetState::from_pb( + &self.session, + self_session, + receiver_chains, + )?) + } + + /// Write modified Double Ratchet state back into this session. + /// + /// Only the ratchet-relevant fields are updated; all other session + /// fields (identity keys, pending pre-key, SPQR state, etc.) are + /// left unchanged. + pub(crate) fn apply_ratchet_state(&mut self, ratchet: crate::double_ratchet::RatchetState) { + ratchet.apply_to_pb(&mut self.session); + } + + #[cfg(test)] pub(crate) fn pq_ratchet_recv( &mut self, msg: &spqr::SerializedMessage, @@ -608,6 +650,7 @@ impl SessionState { Ok(key) } + #[cfg(test)] pub(crate) fn pq_ratchet_send( &mut self, csprng: &mut R, @@ -621,6 +664,19 @@ impl SessionState { pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState { &self.session.pq_ratchet_state } + + /// Move the PQ ratchet state out, leaving an empty vec in its place. + /// + /// The caller must either write a new PQ ratchet state back via + /// [`set_pq_ratchet_state`](Self::set_pq_ratchet_state) or discard + /// this `SessionState` entirely. + pub(crate) fn take_pq_ratchet_state(&mut self) -> spqr::SerializedState { + std::mem::take(&mut self.session.pq_ratchet_state) + } + + pub(crate) fn set_pq_ratchet_state(&mut self, state: spqr::SerializedState) { + self.session.pq_ratchet_state = state; + } } impl From for SessionState { diff --git a/rust/protocol/src/triple_ratchet.rs b/rust/protocol/src/triple_ratchet.rs new file mode 100644 index 000000000..e150239c5 --- /dev/null +++ b/rust/protocol/src/triple_ratchet.rs @@ -0,0 +1,311 @@ +// +// Copyright 2026 Signal Messenger, LLC. +// SPDX-License-Identifier: AGPL-3.0-only +// + +//! The Triple Ratchet: Double Ratchet + SPQR combined into a single +//! encrypt/decrypt interface. +//! +//! Handles MAC computation/verification and AES-CBC encryption/decryption. +//! +//! The session management layer treats this as an opaque box: give it +//! plaintext, get ciphertext; give it ciphertext, get plaintext. Identity +//! checking, session selection, pre-key handling, and storage are NOT +//! this layer's concern. + +use rand::{CryptoRng, Rng}; + +use crate::double_ratchet::RatchetState; +use crate::ratchet::ChainKey; +use crate::session_management::CurrentOrPrevious; +use crate::state::SessionState; +use crate::{ + CiphertextMessageType, IdentityKey, KeyPair, ProtocolAddress, Result, SignalMessage, + SignalProtocolError, +}; + +/// Sender-side Triple Ratchet session. +/// +/// This is intentionally narrower than [`TripleRatchet`]: encrypt only +/// depends on the current sender chain, SPQR state, identities, and metadata. +/// It does not deserialize receiver chains, so corrupt cold receiver-chain +/// state does not block sending. +/// +/// # Ownership contract +/// +/// [`from_session_state`](Self::from_session_state) **moves** the PQ ratchet +/// state out of the session (via [`SessionState::take_pq_ratchet_state`]), +/// leaving it temporarily invalid. The caller must either call +/// [`apply_to_session_state`](Self::apply_to_session_state) on success, or +/// discard the `SessionState` entirely. +pub(crate) struct OutgoingTripleRatchet { + sender_ratchet_key: KeyPair, + sender_chain_key: ChainKey, + previous_counter: u32, + pqr_state: spqr::SerializedState, + session_version: u8, + local_identity_key: IdentityKey, + remote_identity_key: IdentityKey, +} + +impl OutgoingTripleRatchet { + pub(crate) fn from_session_state(state: &mut SessionState) -> Result { + let sender_ratchet_key = KeyPair { + public_key: state.sender_ratchet_key()?, + private_key: state.sender_ratchet_private_key()?, + }; + let sender_chain_key = state.get_sender_chain_key()?; + let pqr_state = state.take_pq_ratchet_state(); + let session_version: u8 = state.session_version()?.try_into().map_err(|_| { + SignalProtocolError::InvalidSessionStructure("version does not fit in u8") + })?; + let local_identity_key = state.local_identity_key()?; + let remote_identity_key = + state + .remote_identity_key()? + .ok_or(SignalProtocolError::InvalidSessionStructure( + "missing remote identity key", + ))?; + + Ok(Self { + sender_ratchet_key, + sender_chain_key, + previous_counter: state.previous_counter(), + pqr_state, + session_version, + local_identity_key, + remote_identity_key, + }) + } + + pub(crate) fn apply_to_session_state(self, state: &mut SessionState) { + state.set_sender_chain_key(&self.sender_chain_key); + state.set_pq_ratchet_state(self.pqr_state); + } + + pub(crate) fn encrypt( + &mut self, + plaintext: &[u8], + local_address: Option<&ProtocolAddress>, + remote_address: &ProtocolAddress, + csprng: &mut R, + ) -> Result { + let spqr::Send { + state: new_pqr_state, + key: pqr_key, + msg: pqr_msg, + } = spqr::send(&self.pqr_state, csprng).map_err(|e| { + SignalProtocolError::InvalidState( + "encrypt", + format!("post-quantum ratchet send error: {e}"), + ) + })?; + + let message_keys = self.sender_chain_key.message_keys().generate_keys(pqr_key); + + let ctext = signal_crypto::aes_256_cbc_encrypt( + plaintext, + message_keys.cipher_key(), + message_keys.iv(), + ) + .map_err(|_| { + log::error!("session state corrupt for {remote_address}"); + SignalProtocolError::InvalidSessionStructure("invalid sender chain message keys") + })?; + + let addresses = local_address.map(|addr| (addr, remote_address)); + + let message = SignalMessage::new( + self.session_version, + message_keys.mac_key(), + addresses, + self.sender_ratchet_key.public_key, + self.sender_chain_key.index(), + self.previous_counter, + &ctext, + &self.local_identity_key, + &self.remote_identity_key, + &pqr_msg, + )?; + + self.sender_chain_key = self.sender_chain_key.next_chain_key(); + self.pqr_state = new_pqr_state; + + Ok(message) + } + + pub(crate) fn session_version(&self) -> u8 { + self.session_version + } + + pub(crate) fn local_identity_key(&self) -> &IdentityKey { + &self.local_identity_key + } +} + +/// A Triple Ratchet session combining Double Ratchet and SPQR. +/// +/// Constructed from a [`SessionState`], this extracts the cryptographic +/// state needed for decrypt into typed fields. After a successful +/// operation, call [`apply_to_session_state`](Self::apply_to_session_state) +/// to write the updated state back. +/// +/// # Ownership contract +/// +/// [`from_session_state`](Self::from_session_state) **moves** the receiver +/// chains and PQ ratchet state out of the session, leaving it temporarily invalid. +/// The caller must either: +/// - Call [`apply_to_session_state`](Self::apply_to_session_state) on success, or +/// - Discard the `SessionState` (e.g., it was a clone for trial decrypt). +pub(crate) struct TripleRatchet { + ratchet: RatchetState, + pqr_state: spqr::SerializedState, + local_identity_key: IdentityKey, + remote_identity_key: IdentityKey, +} + +impl TripleRatchet { + /// Construct from a [`SessionState`] by extracting crypto state. + /// + /// This moves receiver chains and PQ ratchet state out of `state`. + /// See the [ownership contract](Self#ownership-contract) for details. + /// + /// Fails if the session is missing required fields (root key, identity + /// keys, etc.). The caller should map the error appropriately for the + /// context (e.g., "no session available to decrypt"). + pub(crate) fn from_session_state(state: &mut SessionState) -> Result { + let self_session = state.session_with_self()?; + let ratchet = state.take_ratchet_state(self_session)?; + let pqr_state = state.take_pq_ratchet_state(); + let local_identity_key = state.local_identity_key()?; + let remote_identity_key = + state + .remote_identity_key()? + .ok_or(SignalProtocolError::InvalidSessionStructure( + "missing remote identity key", + ))?; + + Ok(Self { + ratchet, + pqr_state, + local_identity_key, + remote_identity_key, + }) + } + + /// Write the updated crypto state back to a [`SessionState`]. + /// + /// Only call this after a successful decrypt — this is how we ensure + /// no state pollution on failure. + pub(crate) fn apply_to_session_state(self, state: &mut SessionState) { + state.apply_ratchet_state(self.ratchet); + state.set_pq_ratchet_state(self.pqr_state); + } + + // -- Decrypt ------------------------------------------------------- + + /// Decrypt a [`SignalMessage`] to plaintext. + /// + /// Performs DR chain key derivation, SPQR key derivation, MAC + /// verification, and AES-CBC decryption. Ratchet and SPQR state are + /// only committed on success — a failed MAC or decryption leaves this + /// session unchanged. + /// + /// `original_message_type` is used for error classification only + /// (PreKey vs Whisper). + pub(crate) fn decrypt( + &mut self, + sender_address: &ProtocolAddress, + recipient_address: Option<&ProtocolAddress>, + ciphertext: &SignalMessage, + original_message_type: CiphertextMessageType, + current_or_previous_for_logging: CurrentOrPrevious, + csprng: &mut R, + ) -> Result> { + // DR: ensure we have a receiver chain, then consume the message key + let their_ephemeral = ciphertext.sender_ratchet_key(); + let counter = ciphertext.counter(); + let chain_key = self + .ratchet + .ensure_receiver_chain(their_ephemeral, csprng)?; + let message_key_gen = self.ratchet.consume_message_key( + their_ephemeral, + chain_key, + counter, + original_message_type, + &sender_address.to_string(), + )?; + + // SPQR recv — compute key but don't commit state yet + let spqr::Recv { + state: new_pqr_state, + key: pqr_key, + } = spqr::recv(&self.pqr_state, ciphertext.pq_ratchet()).map_err(|e| match e { + spqr::Error::StateDecode => SignalProtocolError::InvalidState( + "decrypt", + format!("post-quantum ratchet error: {e}"), + ), + _ => { + log::info!("post-quantum ratchet error in decrypt: {e}"); + SignalProtocolError::InvalidMessage( + original_message_type, + "post-quantum ratchet error", + ) + } + })?; + + // Derive final message keys by mixing DR chain key with SPQR key + let message_keys = message_key_gen.generate_keys(pqr_key); + + // MAC verification + let mac_valid = match recipient_address { + Some(recipient_address) => ciphertext.verify_mac_with_addresses( + sender_address, + recipient_address, + &self.remote_identity_key, + &self.local_identity_key, + message_keys.mac_key(), + )?, + None => ciphertext.verify_mac( + &self.remote_identity_key, + &self.local_identity_key, + message_keys.mac_key(), + )?, + }; + if !mac_valid { + return Err(SignalProtocolError::InvalidMessage( + original_message_type, + "MAC verification failed", + )); + } + + // AES-CBC decrypt + let ptext = match signal_crypto::aes_256_cbc_decrypt( + ciphertext.body(), + message_keys.cipher_key(), + message_keys.iv(), + ) { + Ok(ptext) => ptext, + Err(signal_crypto::DecryptionError::BadKeyOrIv) => { + log::warn!( + "{current_or_previous_for_logging} session state corrupt for {sender_address}", + ); + return Err(SignalProtocolError::InvalidSessionStructure( + "invalid receiver chain message keys", + )); + } + Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => { + log::warn!("failed to decrypt 1:1 message: {msg}"); + return Err(SignalProtocolError::InvalidMessage( + original_message_type, + "failed to decrypt", + )); + } + }; + + // Commit SPQR state only after all verification passed + self.pqr_state = new_pqr_state; + + Ok(ptext) + } +} diff --git a/rust/protocol/tests/ratchet.rs b/rust/protocol/tests/ratchet.rs index 9bf55a059..b02cd884d 100644 --- a/rust/protocol/tests/ratchet.rs +++ b/rust/protocol/tests/ratchet.rs @@ -51,13 +51,12 @@ fn test_alice_and_bob_agree_on_chain_keys_with_kyber() -> Result<(), SignalProto bob_identity_key_pair, bob_signed_pre_key_pair, None, - bob_ephemeral_key_pair, bob_kyber_pre_key_pair, *alice_identity_key_pair.identity_key(), alice_base_key_pair.public_key, &kyber_ciphertext, ); - let bob_record = initialize_bob_session_record(&bob_parameters)?; + let bob_record = initialize_bob_session_record(&bob_parameters, &bob_ephemeral_key_pair)?; assert_eq!( KYBER_AWARE_MESSAGE_VERSION, @@ -132,14 +131,13 @@ fn test_bob_rejects_torsioned_basekey() -> Result<(), SignalProtocolError> { bob_identity_key_pair, bob_signed_pre_key_pair, None, - bob_ephemeral_key_pair, bob_kyber_pre_key_pair, *alice_identity_key_pair.identity_key(), tweaked_alice_base_key, &kyber_ciphertext, ); - assert!(initialize_bob_session_record(&bob_parameters).is_err()); + assert!(initialize_bob_session_record(&bob_parameters, &bob_ephemeral_key_pair).is_err()); Ok(()) } @@ -194,14 +192,13 @@ fn test_bob_rejects_highbit_basekey() -> Result<(), SignalProtocolError> { bob_identity_key_pair, bob_signed_pre_key_pair, None, - bob_ephemeral_key_pair, bob_kyber_pre_key_pair, *alice_identity_key_pair.identity_key(), tweaked_alice_base_key, &kyber_ciphertext, ); - assert!(initialize_bob_session_record(&bob_parameters).is_err()); + assert!(initialize_bob_session_record(&bob_parameters, &bob_ephemeral_key_pair).is_err()); Ok(()) } diff --git a/rust/protocol/tests/support/mod.rs b/rust/protocol/tests/support/mod.rs index 864793c27..2691357dd 100644 --- a/rust/protocol/tests/support/mod.rs +++ b/rust/protocol/tests/support/mod.rs @@ -177,14 +177,13 @@ pub fn initialize_sessions_v4() -> Result<(SessionRecord, SessionRecord), Signal bob_identity, bob_base_key, None, - bob_ephemeral_key, bob_kyber_key, *alice_identity.identity_key(), alice_base_key.public_key, &kyber_ciphertext, ); - let bob_session = initialize_bob_session_record(&bob_params)?; + let bob_session = initialize_bob_session_record(&bob_params, &bob_ephemeral_key)?; Ok((alice_session, bob_session)) }