Files
libsignal/rust/protocol/src/session_cipher.rs
Jack Lloyd e8b4474cb9 Fix handling when attempting to decrypt with a session that isn't found
There were two discrepancies between the logic here and the original
logic of libsignal-protocol-java.

First, if the session record had an uninitialized active session, Java
would still attempt decryption with the old session states, but Rust
would stop immediately without trying the old states. [I am not sure
whether this ever happens, but it could possibly occur due to use of
archiveCurrentState.]

Second, we returned the wrong error condition. We treated the lack of a
sender chain as an invalid state (effectively an internal error), but
Java treats it as an invalid message, which makes sense insofar as it
is a message that we are unable to process with the information we
have available. This wrong error type led to an unexpected exception
being thrown on Android.
2021-01-07 14:17:17 -05:00

408 lines
12 KiB
Rust

//
// Copyright 2020 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
use crate::{
Context, IdentityKeyStore, PreKeyStore, ProtocolAddress, SessionRecord, SessionStore,
SignalProtocolError, SignedPreKeyStore,
};
use crate::consts::MAX_FORWARD_JUMPS;
use crate::crypto;
use crate::curve;
use crate::error::Result;
use crate::protocol::{CiphertextMessage, PreKeySignalMessage, SignalMessage};
use crate::ratchet::{ChainKey, MessageKeys};
use crate::session;
use crate::state::SessionState;
use crate::storage::Direction;
use rand::{CryptoRng, Rng};
/// Encrypts `ptext` for `remote_address` using the stored session's current state.
///
/// The sender chain key is advanced by one step. If the session still carries
/// unacknowledged pre-key items, the resulting `SignalMessage` is wrapped in a
/// `PreKeySignalMessage`; otherwise the plain `SignalMessage` is returned.
///
/// # Errors
/// - `SessionNotFound` if no session record is stored for `remote_address`.
/// - `InvalidSessionStructure` if the current state has no remote identity key.
/// - `UntrustedIdentity` if `identity_store` rejects the remote identity key; the
///   advanced session is then *not* written back to `session_store`.
/// - Any error propagated from the stores, the session state accessors or
///   message construction.
pub async fn message_encrypt(
    ptext: &[u8],
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    ctx: Context,
) -> Result<CiphertextMessage> {
    let mut record = session_store
        .load_session(remote_address, ctx)
        .await?
        .ok_or(SignalProtocolError::SessionNotFound)?;
    let state = record.session_state_mut()?;

    let sender_chain = state.get_sender_chain_key()?;
    let keys = sender_chain.message_keys()?;
    let ratchet_key = state.sender_ratchet_key()?;
    let previous_counter = state.previous_counter()?;
    let version = state.session_version()? as u8;
    let our_identity = state.local_identity_key()?;
    let their_identity = state
        .remote_identity_key()?
        .ok_or(SignalProtocolError::InvalidSessionStructure)?;

    let body = crypto::aes_256_cbc_encrypt(ptext, keys.cipher_key(), keys.iv())?;

    // Pending pre-key items (if any), paired with our registration id, which is
    // only needed when the message must be wrapped as a PreKeySignalMessage.
    let pre_key_info = match state.unacknowledged_pre_key_message_items()? {
        Some(items) => Some((items, state.local_registration_id()?)),
        None => None,
    };

    let signal_message = SignalMessage::new(
        version,
        keys.mac_key(),
        ratchet_key,
        sender_chain.index(),
        previous_counter,
        &body,
        &our_identity,
        &their_identity,
    )?;

    let message = match pre_key_info {
        Some((items, registration_id)) => {
            CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
                version,
                registration_id,
                items.pre_key_id()?,
                items.signed_pre_key_id()?,
                *items.base_key()?,
                our_identity,
                signal_message,
            )?)
        }
        None => CiphertextMessage::SignalMessage(signal_message),
    };

    state.set_sender_chain_key(&sender_chain.next_chain_key()?)?;

    // NOTE(review): the trust check only runs after the message has been built.
    // Nothing has been persisted yet (the session is stored at the very end), so
    // an untrusted identity leaves both stores untouched; the work above is
    // simply discarded.
    let trusted = identity_store
        .is_trusted_identity(remote_address, &their_identity, Direction::Sending, ctx)
        .await?;
    if !trusted {
        return Err(SignalProtocolError::UntrustedIdentity(
            remote_address.clone(),
        ));
    }

    // XXX this could be combined with the above call to the identity store (in a new API)
    identity_store
        .save_identity(remote_address, &their_identity, ctx)
        .await?;
    session_store
        .store_session(remote_address, &record, ctx)
        .await?;
    Ok(message)
}
/// Decrypts any supported `CiphertextMessage` received from `remote_address`.
///
/// `SignalMessage`s are routed to [`message_decrypt_signal`] and
/// `PreKeySignalMessage`s to [`message_decrypt_prekey`]. Every other variant is
/// rejected with `InvalidArgument`.
pub async fn message_decrypt<R: Rng + CryptoRng>(
    ciphertext: &CiphertextMessage,
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    pre_key_store: &mut dyn PreKeyStore,
    signed_pre_key_store: &mut dyn SignedPreKeyStore,
    csprng: &mut R,
    ctx: Context,
) -> Result<Vec<u8>> {
    let plaintext = match ciphertext {
        CiphertextMessage::SignalMessage(signal_message) => {
            message_decrypt_signal(
                signal_message,
                remote_address,
                session_store,
                identity_store,
                csprng,
                ctx,
            )
            .await?
        }
        CiphertextMessage::PreKeySignalMessage(prekey_message) => {
            message_decrypt_prekey(
                prekey_message,
                remote_address,
                session_store,
                identity_store,
                pre_key_store,
                signed_pre_key_store,
                csprng,
                ctx,
            )
            .await?
        }
        _ => {
            return Err(SignalProtocolError::InvalidArgument(
                "SessionCipher::decrypt cannot decrypt this message type".to_owned(),
            ))
        }
    };
    Ok(plaintext)
}
/// Decrypts a `PreKeySignalMessage` received from `remote_address`.
///
/// The stored session record is loaded, or a fresh one is created when none
/// exists. `session::process_prekey` is applied to it, and the embedded
/// `SignalMessage` is then decrypted against the resulting record.
///
/// The record is written back to `session_store` only after decryption
/// succeeds. Afterwards, the pre-key id returned by `process_prekey` (if any)
/// is removed from `pre_key_store`.
pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
    ciphertext: &PreKeySignalMessage,
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    pre_key_store: &mut dyn PreKeyStore,
    signed_pre_key_store: &mut dyn SignedPreKeyStore,
    csprng: &mut R,
    ctx: Context,
) -> Result<Vec<u8>> {
    let mut record = match session_store.load_session(remote_address, ctx).await? {
        Some(existing) => existing,
        None => SessionRecord::new_fresh(),
    };

    let used_pre_key_id = session::process_prekey(
        ciphertext,
        remote_address,
        &mut record,
        identity_store,
        pre_key_store,
        signed_pre_key_store,
        ctx,
    )
    .await?;

    let plaintext = decrypt_message_with_record(&mut record, ciphertext.message(), csprng)?;

    session_store
        .store_session(remote_address, &record, ctx)
        .await?;

    if let Some(id) = used_pre_key_id {
        pre_key_store.remove_pre_key(id, ctx).await?;
    }
    Ok(plaintext)
}
/// Decrypts a `SignalMessage` received from `remote_address` using an existing session.
///
/// # Errors
/// - `SessionNotFound` if no session record is stored for `remote_address`.
/// - Any error from decrypting against the record's session states.
/// - `InvalidSessionStructure` if the record's current state (after decryption)
///   has no remote identity key.
/// - `UntrustedIdentity` if `identity_store` rejects that key; the updated
///   record is then *not* written back to `session_store`.
pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
    ciphertext: &SignalMessage,
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    csprng: &mut R,
    ctx: Context,
) -> Result<Vec<u8>> {
    let mut record = session_store
        .load_session(remote_address, ctx)
        .await?
        .ok_or(SignalProtocolError::SessionNotFound)?;

    let plaintext = decrypt_message_with_record(&mut record, ciphertext, csprng)?;

    // The identity key is read only after decryption. A successful decrypt
    // with an older session state calls `promote_old_session` on the record,
    // so `session_state()` presumably refers to the state that actually
    // decrypted the message only at this point.
    // NOTE(review): confirm whether the trust check itself could run earlier.
    let their_identity = record
        .session_state()?
        .remote_identity_key()?
        .ok_or(SignalProtocolError::InvalidSessionStructure)?;

    let trusted = identity_store
        .is_trusted_identity(remote_address, &their_identity, Direction::Receiving, ctx)
        .await?;
    if !trusted {
        return Err(SignalProtocolError::UntrustedIdentity(
            remote_address.clone(),
        ));
    }

    identity_store
        .save_identity(remote_address, &their_identity, ctx)
        .await?;
    session_store
        .store_session(remote_address, &record, ctx)
        .await?;
    Ok(plaintext)
}
fn decrypt_message_with_record<R: Rng + CryptoRng>(
record: &mut SessionRecord,
ciphertext: &SignalMessage,
csprng: &mut R,
) -> Result<Vec<u8>> {
if let Ok(current_state) = record.session_state() {
let mut current_state = current_state.clone();
let result = decrypt_message_with_state(&mut current_state, ciphertext, csprng);
match result {
Ok(ptext) => {
record.set_session_state(current_state)?; // update the state
return Ok(ptext);
}
Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
return result;
}
Err(_) => {}
}
}
// Try some old sessions:
let mut updated_session = None;
for (idx, previous) in record.previous_session_states()?.enumerate() {
let mut updated = previous.clone();
let result = decrypt_message_with_state(&mut updated, ciphertext, csprng);
match result {
Ok(ptext) => {
updated_session = Some((ptext, idx, updated));
break;
}
Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
return result;
}
_ => {}
}
}
if let Some((ptext, idx, updated_session)) = updated_session {
record.promote_old_session(idx, updated_session)?;
Ok(ptext)
} else {
Err(SignalProtocolError::InvalidMessage(
"decryption failed; no matching session state",
))
}
}
/// Decrypts `ciphertext` with a single session `state`, deriving (and storing
/// in `state`) whatever receiving chain and message keys are needed.
///
/// `state` is mutated even when an error is returned. Chain and message keys
/// are derived before the MAC is checked. Callers should therefore pass a copy
/// they can discard on failure.
///
/// # Errors
/// - `InvalidMessage` if `state` has no sender chain, or if the message counter
///   is too far ahead of the receiving chain.
/// - `UnrecognizedMessageVersion` if the message version differs from the
///   session version.
/// - `DuplicatedMessage` if the keys for an older counter are no longer stored.
/// - `InvalidSessionStructure` if `state` has no remote identity key.
/// - `InvalidCiphertext` if the MAC does not verify.
fn decrypt_message_with_state<R: Rng + CryptoRng>(
    state: &mut SessionState,
    ciphertext: &SignalMessage,
    csprng: &mut R,
) -> Result<Vec<u8>> {
    // A missing sender chain is reported as an invalid message (matching the
    // Java implementation) rather than as an internal error.
    if !state.has_sender_chain()? {
        return Err(SignalProtocolError::InvalidMessage(
            "No session available to decrypt",
        ));
    }

    let message_version = ciphertext.message_version() as u32;
    if message_version != state.session_version()? {
        return Err(SignalProtocolError::UnrecognizedMessageVersion(
            message_version,
        ));
    }

    let their_ratchet_key = ciphertext.sender_ratchet_key();
    let message_counter = ciphertext.counter();
    let receiving_chain = get_or_create_chain_key(state, their_ratchet_key, csprng)?;
    let keys = get_or_create_message_key(state, their_ratchet_key, &receiving_chain, message_counter)?;

    let their_identity = state
        .remote_identity_key()?
        .ok_or(SignalProtocolError::InvalidSessionStructure)?;
    let our_identity = state.local_identity_key()?;
    if !ciphertext.verify_mac(&their_identity, &our_identity, keys.mac_key())? {
        return Err(SignalProtocolError::InvalidCiphertext);
    }

    let plaintext = crypto::aes_256_cbc_decrypt(ciphertext.body(), keys.cipher_key(), keys.iv())?;

    // NOTE(review): presumably a successfully decrypted reply means the peer
    // no longer needs our pre-key data; confirm against the protocol spec.
    state.clear_unacknowledged_pre_key_message()?;
    Ok(plaintext)
}
/// Returns the receiving chain key associated with `their_ephemeral`.
///
/// If `state` already has a receiver chain for that key, it is returned as-is.
/// Otherwise the state is advanced:
///
/// 1. A receiver chain is derived from the current root key and our current
///    sender ratchet private key.
/// 2. A fresh sending key pair is generated from `csprng`.
/// 3. A second derivation from the intermediate root key yields the new root
///    key and sender chain.
/// 4. The previous counter is set to the old sender chain index minus one,
///    saturating at 0.
/// 5. The new sender chain is installed.
///
/// The newly created receiving chain key is returned.
fn get_or_create_chain_key<R: Rng + CryptoRng>(
    state: &mut SessionState,
    their_ephemeral: &curve::PublicKey,
    csprng: &mut R,
) -> Result<ChainKey> {
    if let Some(existing) = state.get_receiver_chain_key(their_ephemeral)? {
        return Ok(existing);
    }

    let root_key = state.root_key()?;
    let our_ephemeral = state.sender_ratchet_private_key()?;
    let (intermediate_root, receiving_chain) =
        root_key.create_chain(their_ephemeral, &our_ephemeral)?;

    let our_new_ephemeral = curve::KeyPair::generate(csprng);
    let (new_root, sending_chain) =
        intermediate_root.create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;

    state.set_root_key(&new_root)?;
    state.add_receiver_chain(their_ephemeral, &receiving_chain)?;

    // Chain indices are u32, so this is "index - 1", clamped at zero.
    let previous_counter = state.get_sender_chain_key()?.index().saturating_sub(1);
    state.set_previous_counter(previous_counter)?;
    state.set_sender_chain(&our_new_ephemeral, &sending_chain)?;

    Ok(receiving_chain)
}
/// Returns the message keys for message number `counter` on the receiving
/// chain identified by `their_ephemeral`, whose current key is `chain_key`.
///
/// - If `counter` is behind the chain index, the keys must already be stored
///   in `state` for that skipped message. If they are not,
///   `DuplicatedMessage(chain_index, counter)` is returned.
/// - Otherwise the chain is walked forward to `counter`. The keys of every
///   skipped index are stored in `state`, and the receiver chain key is set to
///   the key following `counter`. A jump larger than `MAX_FORWARD_JUMPS` is
///   rejected with `InvalidMessage` before any state is changed.
fn get_or_create_message_key(
    state: &mut SessionState,
    their_ephemeral: &curve::PublicKey,
    chain_key: &ChainKey,
    counter: u32,
) -> Result<MessageKeys> {
    let chain_index = chain_key.index();

    if counter < chain_index {
        return state
            .get_message_keys(their_ephemeral, counter)?
            .ok_or(SignalProtocolError::DuplicatedMessage(chain_index, counter));
    }

    // Past the guard above chain_index <= counter, so this cannot underflow.
    if (counter - chain_index) as usize > MAX_FORWARD_JUMPS {
        return Err(SignalProtocolError::InvalidMessage(
            "message from too far into the future",
        ));
    }

    let mut current = chain_key.clone();
    while current.index() < counter {
        // Keep keys for skipped messages so they can still be decrypted later.
        state.set_message_keys(their_ephemeral, &current.message_keys()?)?;
        current = current.next_chain_key()?;
    }

    state.set_receiver_chain_key(their_ephemeral, &current.next_chain_key()?)?;
    Ok(current.message_keys()?)
}