diff --git a/rust/protocol/src/proto/storage.proto b/rust/protocol/src/proto/storage.proto index b67d2f3b0..55dc91385 100644 --- a/rust/protocol/src/proto/storage.proto +++ b/rust/protocol/src/proto/storage.proto @@ -27,6 +27,10 @@ message SessionStructure { } repeated MessageKey message_keys = 4; + + // This is inside the sender chain so it automatically gets cleared on ratchet. + // It should never be set on a receiver chain. + bool needs_pni_signature = 5; } message PendingPreKey { diff --git a/rust/protocol/src/ratchet.rs b/rust/protocol/src/ratchet.rs index 334dc711e..f6f0da15f 100644 --- a/rust/protocol/src/ratchet.rs +++ b/rust/protocol/src/ratchet.rs @@ -31,6 +31,7 @@ fn derive_keys(secret_input: &[u8]) -> Result<(RootKey, ChainKey)> { pub(crate) fn initialize_alice_session( parameters: &AliceSignalProtocolParameters, + mut csprng: &mut R, ) -> Result { let local_identity = parameters.our_identity_key_pair().identity_key(); diff --git a/rust/protocol/src/state/session.rs b/rust/protocol/src/state/session.rs index 3fa5fc7d8..b1afbead2 100644 --- a/rust/protocol/src/state/session.rs +++ b/rust/protocol/src/state/session.rs @@ -216,6 +216,7 @@ impl SessionState { sender_ratchet_key_private: vec![], chain_key: Some(chain_key), message_keys: vec![], + needs_pni_signature: false, }; self.session.receiver_chains.push(chain); @@ -248,6 +249,7 @@ impl SessionState { sender_ratchet_key_private: sender.private_key.serialize().to_vec(), chain_key: Some(chain_key), message_keys: vec![], + needs_pni_signature: false, }; self.session.sender_chain = Some(new_chain); @@ -289,6 +291,7 @@ impl SessionState { sender_ratchet_key_private: vec![], chain_key: Some(chain_key), message_keys: vec![], + needs_pni_signature: false, }, Some(mut c) => { c.chain_key = Some(chain_key); @@ -446,6 +449,24 @@ impl SessionState { pub(crate) fn local_registration_id(&self) -> Result { Ok(self.session.local_registration_id) } + + pub(crate) fn needs_pni_signature(&self) -> bool { + self.session + .sender_chain + .as_ref() + .map_or(false, |chain| chain.needs_pni_signature) + } + + pub(crate) fn set_needs_pni_signature(&mut self, needs_pni_signature: bool) -> Result<()> { + let chain = &mut self.session.sender_chain.as_mut().ok_or_else(|| { + SignalProtocolError::InvalidState( + "set_needs_pni_signature", + "No sender chain".to_string(), + ) + })?; + chain.needs_pni_signature = needs_pni_signature; + Ok(()) + } } impl From for SessionState { @@ -633,6 +654,15 @@ impl SessionRecord { } } + pub fn needs_pni_signature(&self) -> Result { + Ok(self.session_state()?.needs_pni_signature()) + } + + pub fn set_needs_pni_signature(&mut self, needs_pni_signature: bool) -> Result<()> { + self.session_state_mut()? + .set_needs_pni_signature(needs_pni_signature) + } + pub fn alice_base_key(&self) -> Result<&[u8]> { self.session_state()?.alice_base_key() } diff --git a/rust/protocol/tests/session.rs b/rust/protocol/tests/session.rs index a389cc5f8..2fb2c48f6 100644 --- a/rust/protocol/tests/session.rs +++ b/rust/protocol/tests/session.rs @@ -68,6 +68,18 @@ fn test_basic_prekey_v3() -> Result<(), SignalProtocolError> { .session_version()?, 3 ); + assert!(alice_store + .load_session(&bob_address, None) + .await? + .expect("session found") + .has_sender_chain() + .expect("can ask about sender chain")); + assert!(!alice_store + .load_session(&bob_address, None) + .await? + .expect("session found") + .needs_pni_signature() + .expect("has current state")); let original_message = "L'homme est condamné à être libre"; @@ -121,6 +133,12 @@ fn test_basic_prekey_v3() -> Result<(), SignalProtocolError> { .expect("session found"); assert_eq!(bobs_session_with_alice.session_version()?, 3); assert_eq!(bobs_session_with_alice.alice_base_key()?.len(), 32 + 1); + assert!(bobs_session_with_alice + .has_sender_chain() + .expect("can ask about sender chain")); + assert!(!bobs_session_with_alice + .needs_pni_signature() + .expect("has current state")); let bob_outgoing = encrypt(&mut bob_store, &alice_address, bobs_response).await?; @@ -1988,3 +2006,81 @@ fn simultaneous_initiate_lost_message_repeated_messages() -> Result<(), SignalPr .now_or_never() .expect("sync") } + +#[test] +fn test_needs_pni_signature() -> Result<(), SignalProtocolError> { + async { + let mut csprng = OsRng; + + let alice_address = ProtocolAddress::new("+14151111111".to_owned(), 1); + let bob_address = ProtocolAddress::new("+14151111112".to_owned(), 1); + + let mut alice_store = support::test_in_memory_protocol_store()?; + let mut bob_store = support::test_in_memory_protocol_store()?; + + let bob_pre_key_bundle = create_pre_key_bundle(&mut bob_store, &mut csprng).await?; + + process_prekey_bundle( + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &bob_pre_key_bundle, + &mut csprng, + None, + ) + .await?; + + // Not set by default. + let mut alice_session_with_bob = alice_store + .load_session(&bob_address, None) + .await? + .expect("session found"); + assert!(!alice_session_with_bob + .needs_pni_signature() + .expect("has current session")); + + alice_session_with_bob + .set_needs_pni_signature(true) + .expect("has current session"); + + assert!(alice_session_with_bob + .needs_pni_signature() + .expect("has current session")); + + alice_store + .store_session(&bob_address, &alice_session_with_bob, None) + .await?; + + // Sending a message doesn't clear the state... + let message = encrypt(&mut alice_store, &bob_address, "SYN").await?; + + assert!(alice_store + .load_session(&bob_address, None) + .await? + .expect("session found") + .needs_pni_signature() + .expect("has current session")); + + // ...but receiving one does. + let _ = decrypt(&mut bob_store, &alice_address, &message).await?; + + let reply = encrypt(&mut bob_store, &alice_address, "ACK").await?; + let _ = decrypt(&mut alice_store, &bob_address, &reply).await?; + + let mut alice_session_with_bob = alice_store + .load_session(&bob_address, None) + .await? + .expect("session found"); + assert!(!alice_session_with_bob + .needs_pni_signature() + .expect("has current session")); + + // If you archive the session, you don't get to ask. + alice_session_with_bob.archive_current_state()?; + assert!(alice_session_with_bob.needs_pni_signature().is_err()); + + Ok(()) + } + .now_or_never() + .expect("sync") +}