diff --git a/Cargo.lock b/Cargo.lock index e4612c890..db5f481c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2514,14 +2514,14 @@ dependencies = [ [[package]] name = "libsignal-debug" -version = "0.90.1" +version = "0.91.0" dependencies = [ "cfg-if", ] [[package]] name = "libsignal-ffi" -version = "0.90.1" +version = "0.91.0" dependencies = [ "cpufeatures", "hex", @@ -2542,7 +2542,7 @@ dependencies = [ [[package]] name = "libsignal-jni" -version = "0.90.1" +version = "0.91.0" dependencies = [ "libsignal-debug", "libsignal-jni-impl", @@ -2550,7 +2550,7 @@ dependencies = [ [[package]] name = "libsignal-jni-impl" -version = "0.90.1" +version = "0.91.0" dependencies = [ "cfg-if", "cpufeatures", @@ -2567,7 +2567,7 @@ dependencies = [ [[package]] name = "libsignal-jni-testing" -version = "0.90.1" +version = "0.91.0" dependencies = [ "jni", "libsignal-bridge-testing", @@ -2882,7 +2882,7 @@ dependencies = [ [[package]] name = "libsignal-node" -version = "0.90.1" +version = "0.91.0" dependencies = [ "futures", "libsignal-bridge", diff --git a/Cargo.toml b/Cargo.toml index 20ba7d2ba..ceae4ccd3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ default-members = [ resolver = "2" # so that our dev-dependency features don't leak into products [workspace.package] -version = "0.90.1" +version = "0.91.0" authors = ["Signal Messenger LLC"] license = "AGPL-3.0-only" rust-version = "1.88" diff --git a/LibSignalClient.podspec b/LibSignalClient.podspec index f427bd5f4..d8f050518 100644 --- a/LibSignalClient.podspec +++ b/LibSignalClient.podspec @@ -5,7 +5,7 @@ Pod::Spec.new do |s| s.name = 'LibSignalClient' - s.version = '0.90.1' + s.version = '0.91.0' s.summary = 'A Swift wrapper library for communicating with the Signal messaging service.' s.homepage = 'https://github.com/signalapp/libsignal' diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 123c5db87..c45178523 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,4 +1,5 @@ -v0.90.1 +v0.91.0 - Support gRPC for getUploadForm() +- 1:1 message decryption now takes the local address as an extra argument diff --git a/java/build.gradle b/java/build.gradle index 85e3c7eae..b595d9a2b 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -23,7 +23,7 @@ repositories { } allprojects { - version = "0.90.1" + version = "0.91.0" group = "org.signal" tasks.withType(JavaCompile) { diff --git a/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java b/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java index 6ef4df317..c151208ed 100644 --- a/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java +++ b/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java @@ -264,7 +264,9 @@ public class SealedSessionCipher { .decrypt(new SignalMessage(message.getContent())); case CiphertextMessage.PREKEY_TYPE: return new SessionCipher(signalProtocolStore, sender) - .decrypt(new PreKeySignalMessage(message.getContent())); + .decrypt( + new PreKeySignalMessage(message.getContent()), + new SignalProtocolAddress(localUuidAddress, localDeviceId)); case CiphertextMessage.SENDERKEY_TYPE: return new GroupCipher(signalProtocolStore, sender).decrypt(message.getContent()); case CiphertextMessage.PLAINTEXT_CONTENT_TYPE: diff --git a/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java b/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java index 444fe9ed7..f97f6f63a 100644 --- a/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java +++ b/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java @@ -86,7 +86,7 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessage.serialize()); SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -142,7 +142,7 @@ public class SessionBuilderTest { var bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize()), BOB_ADDRESS); fail("shouldn't be trusted!"); } catch (UntrustedIdentityException uie) { bobStore.saveIdentity( @@ -150,7 +150,8 @@ public class SessionBuilderTest { } var plaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize())); + bobSessionCipher.decrypt( + new PreKeySignalMessage(outgoingMessage.serialize()), BOB_ADDRESS); assertTrue(new String(plaintext).equals(originalMessage)); Random random = new Random(); @@ -206,7 +207,7 @@ public class SessionBuilderTest { SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); assertTrue(originalMessage.equals(new String(plaintext))); CiphertextMessage bobOutgoingMessage = bobSessionCipher.encrypt(originalMessage.getBytes()); @@ -220,7 +221,9 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessageTwo = new PreKeySignalMessage(outgoingMessageTwo.serialize()); - plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(incomingMessageTwo.serialize())); + plaintext = + bobSessionCipher.decrypt( + new PreKeySignalMessage(incomingMessageTwo.serialize()), BOB_ADDRESS); assertTrue(originalMessage.equals(new String(plaintext))); bobOutgoingMessage = bobSessionCipher.encrypt(originalMessage.getBytes()); @@ -267,7 +270,7 @@ public class SessionBuilderTest { assertTrue(!incomingMessage.getPreKeyId().isPresent()); SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -356,14 +359,15 @@ public class SessionBuilderTest { assertTrue(!incomingMessage.getPreKeyId().isPresent()); SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - bobSessionCipher.decrypt(incomingMessage); + bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); SessionCipher bobSessionCipherForMallory = new SessionCipher(bobStore, MALLORY_ADDRESS); assertThrows( - ReusedBaseKeyException.class, () -> bobSessionCipherForMallory.decrypt(incomingMessage)); + ReusedBaseKeyException.class, + () -> bobSessionCipherForMallory.decrypt(incomingMessage, BOB_ADDRESS)); } } @@ -504,7 +508,7 @@ public class SessionBuilderTest { byte[] plaintext = new byte[0]; try { - plaintext = bobSessionCipher.decrypt(incomingMessage); + plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); fail("Decrypt should have failed!"); } catch (InvalidMessageException e) { // good. @@ -512,7 +516,7 @@ public class SessionBuilderTest { assertTrue(bobStore.containsPreKey(bobPreKey.getPreKeyId())); - plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(goodMessage)); + plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(goodMessage), BOB_ADDRESS); assertTrue(originalMessage.equals(new String(plaintext))); assertFalse(bobStore.containsPreKey(bobPreKey.getPreKeyId())); @@ -546,7 +550,7 @@ public class SessionBuilderTest { SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(incomingMessage); + bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); fail("Decrypt should have failed!"); } catch (InvalidKeyIdException e) { assertEquals( @@ -583,7 +587,7 @@ public class SessionBuilderTest { SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(incomingMessage); + bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); fail("Decrypt should have failed!"); } catch (InvalidKeyIdException e) { fail("libsignal swallowed the exception"); diff --git a/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java b/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java index 01b594d41..cf512f4c5 100644 --- a/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java +++ b/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java @@ -79,9 +79,10 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); + aliceSessionCipher.decrypt( + new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -145,7 +146,7 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(bobPlaintext).equals("hey there")); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -155,7 +156,7 @@ public class SimultaneousInitiateTests { assertEquals(aliceResponse.getType(), CiphertextMessage.PREKEY_TYPE); byte[] responsePlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(aliceResponse.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(aliceResponse.serialize()), BOB_ADDRESS); assertTrue(new String(responsePlaintext).equals("second message")); assertSessionIdEquals(aliceStore, bobStore); @@ -204,9 +205,10 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); + aliceSessionCipher.decrypt( + new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -266,9 +268,10 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); + aliceSessionCipher.decrypt( + new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -354,9 +357,10 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); + aliceSessionCipher.decrypt( + new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -449,9 +453,10 @@ public class SimultaneousInitiateTests { assertFalse(isSessionIdEqual(aliceStore, bobStore)); byte[] alicePlaintext = - aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); + aliceSessionCipher.decrypt( + new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -503,7 +508,8 @@ public class SimultaneousInitiateTests { assertTrue(isSessionIdEqual(aliceStore, bobStore)); byte[] lostMessagePlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(lostMessageForBob.serialize())); + bobSessionCipher.decrypt( + new PreKeySignalMessage(lostMessageForBob.serialize()), BOB_ADDRESS); assertTrue(new String(lostMessagePlaintext).equals("hey there")); assertFalse(isSessionIdEqual(aliceStore, bobStore)); diff --git a/java/shared/java/org/signal/libsignal/internal/Native.kt b/java/shared/java/org/signal/libsignal/internal/Native.kt index a20132a25..ea95ccfbf 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.kt +++ b/java/shared/java/org/signal/libsignal/internal/Native.kt @@ -1180,7 +1180,7 @@ internal object Native { public external fun SessionBuilder_ProcessPreKeyBundle(bundle: ObjectHandle, protocolAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Long): Unit @JvmStatic @Throws(Exception::class) - public external fun SessionCipher_DecryptPreKeySignalMessage(message: ObjectHandle, protocolAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, prekeyStore: PreKeyStore, signedPrekeyStore: SignedPreKeyStore, kyberPrekeyStore: KyberPreKeyStore): ByteArray + public external fun SessionCipher_DecryptPreKeySignalMessage(message: ObjectHandle, protocolAddress: ObjectHandle, localAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, prekeyStore: PreKeyStore, signedPrekeyStore: SignedPreKeyStore, kyberPrekeyStore: KyberPreKeyStore): ByteArray @JvmStatic @Throws(Exception::class) public external fun SessionCipher_DecryptSignalMessage(message: ObjectHandle, protocolAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore): ByteArray @JvmStatic @Throws(Exception::class) diff --git a/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java b/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java index 70e9cf2c2..e740e62dc 100644 --- a/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java +++ b/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java @@ -181,14 +181,15 @@ public class SessionCipher { * @throws InvalidKeyException when the message is formatted incorrectly. * @throws UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted. */ - public byte[] decrypt(PreKeySignalMessage ciphertext) + public byte[] decrypt(PreKeySignalMessage ciphertext, SignalProtocolAddress localAddress) throws DuplicateMessageException, InvalidMessageException, InvalidKeyIdException, InvalidKeyException, UntrustedIdentityException { try (NativeHandleGuard ciphertextGuard = new NativeHandleGuard(ciphertext); - NativeHandleGuard remoteAddressGuard = new NativeHandleGuard(this.remoteAddress); ) { + NativeHandleGuard remoteAddressGuard = new NativeHandleGuard(this.remoteAddress); + NativeHandleGuard localAddressGuard = new NativeHandleGuard(localAddress); ) { return filterExceptions( DuplicateMessageException.class, InvalidMessageException.class, @@ -199,6 +200,7 @@ public class SessionCipher { Native.SessionCipher_DecryptPreKeySignalMessage( ciphertextGuard.nativeHandle(), remoteAddressGuard.nativeHandle(), + localAddressGuard.nativeHandle(), bridge(sessionStore), _bridge(identityKeyStore), new org.signal.libsignal.protocol.state.internal.PreKeyStore() { diff --git a/node/package-lock.json b/node/package-lock.json index 281c77a9e..234409aba 100644 --- a/node/package-lock.json +++ b/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "@signalapp/libsignal-client", - "version": "0.90.1", + "version": "0.91.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@signalapp/libsignal-client", - "version": "0.90.1", + "version": "0.91.0", "hasInstallScript": true, "license": "AGPL-3.0-only", "dependencies": { @@ -874,7 +874,6 @@ "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.44.0.tgz", "integrity": "sha512-VGMpFQGUQWYT9LfnPcX8ouFojyrZ/2w3K5BucvxL/spdNehccKhB4jUyB1yBCXpr2XFm0jkECxgrpXBW2ipoAw==", "dev": true, - "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.44.0", "@typescript-eslint/types": "8.44.0", @@ -1072,7 +1071,6 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==", "dev": true, - "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -1475,7 +1473,6 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-6.0.1.tgz", "integrity": "sha512-/JOoU2//6p5vCXh00FpNgtlw0LjvhGttaWc+y7wpW9yjBm3ys0dI8tSKZxIOgNruz5J0RleccatSIC3uxEZP0g==", "dev": true, - "peer": true, "engines": { "node": ">=18" } @@ -2071,7 +2068,6 @@ "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.35.0.tgz", "integrity": "sha512-QePbBFMJFjgmlE+cXAlbHZbHpdFVS2E/6vzCy7aKlebddvl1vadiC4JFV5u/wqTkNUwEV8WrQi257jf5f06hrg==", "dev": true, - "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", @@ -2231,7 +2227,6 @@ "resolved": "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.32.0.tgz", "integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==", "dev": true, - "peer": true, "dependencies": { "@rtsao/scc": "^1.1.0", "array-includes": "^3.1.9", @@ -4402,7 +4397,6 @@ "resolved": "https://registry.npmjs.org/prettier/-/prettier-2.8.8.tgz", "integrity": "sha512-tdN8qQGvNjw4CHbY+XXk0JgCXn9QiF21a55rBe5LJAU+kDyC4WQn4+awm2Xfk2lQMk5fKup9XgzTZtGkjBdP9Q==", "dev": true, - "peer": true, "bin": { "prettier": "bin-prettier.js" }, @@ -5126,7 +5120,6 @@ "resolved": "https://registry.npmjs.org/sinon/-/sinon-21.0.0.tgz", "integrity": "sha512-TOgRcwFPbfGtpqvZw+hyqJDvqfapr1qUlOizROIk4bBLjlsjlB00Pg6wMFXNtJRpu+eCZuVOaLatG7M8105kAw==", "dev": true, - "peer": true, "dependencies": { "@sinonjs/commons": "^3.0.1", "@sinonjs/fake-timers": "^13.0.5", @@ -5550,7 +5543,6 @@ "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-4.0.3.tgz", "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, - "peer": true, "engines": { "node": ">=12" }, @@ -5734,7 +5726,6 @@ "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==", "dev": true, "license": "Apache-2.0", - "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" diff --git a/node/package.json b/node/package.json index e6842700b..428f72451 100644 --- a/node/package.json +++ b/node/package.json @@ -1,6 +1,6 @@ { "name": "@signalapp/libsignal-client", - "version": "0.90.1", + "version": "0.91.0", "repository": "github:signalapp/libsignal", "license": "AGPL-3.0-only", "type": "module", diff --git a/node/ts/Native.ts b/node/ts/Native.ts index 41063c346..d7db4805a 100644 --- a/node/ts/Native.ts +++ b/node/ts/Native.ts @@ -311,7 +311,7 @@ type NativeFunctions = { SessionBuilder_ProcessPreKeyBundle: (bundle: Wrapper, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Timestamp) => Promise; SessionCipher_EncryptMessage: (ptext: Uint8Array, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Timestamp) => Promise; SessionCipher_DecryptSignalMessage: (message: Wrapper, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore) => Promise>; - SessionCipher_DecryptPreKeySignalMessage: (message: Wrapper, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, prekeyStore: PreKeyStore, signedPrekeyStore: SignedPreKeyStore, kyberPrekeyStore: KyberPreKeyStore) => Promise>; + SessionCipher_DecryptPreKeySignalMessage: (message: Wrapper, protocolAddress: Wrapper, localAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, prekeyStore: PreKeyStore, signedPrekeyStore: SignedPreKeyStore, kyberPrekeyStore: KyberPreKeyStore) => Promise>; SealedSender_Encrypt: (destination: Wrapper, content: Wrapper, identityKeyStore: IdentityKeyStore) => Promise>; SealedSender_MultiRecipientEncrypt: (recipients: Wrapper[], recipientSessions: Wrapper[], excludedRecipients: Uint8Array, content: Wrapper, identityKeyStore: IdentityKeyStore) => Promise>; SealedSender_MultiRecipientMessageForSingleRecipient: (encodedMultiRecipientMessage: Uint8Array) => Uint8Array; diff --git a/node/ts/index.ts b/node/ts/index.ts index 225cc3c25..96a78222b 100644 --- a/node/ts/index.ts +++ b/node/ts/index.ts @@ -1390,6 +1390,7 @@ function bridgeKyberPreKeyStore( export function signalDecryptPreKey( message: PreKeySignalMessage, address: ProtocolAddress, + localAddress: ProtocolAddress, sessionStore: SessionStore, identityStore: IdentityKeyStore, prekeyStore: PreKeyStore, @@ -1399,6 +1400,7 @@ export function signalDecryptPreKey( return Native.SessionCipher_DecryptPreKeySignalMessage( message, address, + localAddress, bridgeSessionStore(sessionStore), bridgeIdentityKeyStore(identityStore), bridgePreKeyStore(prekeyStore), diff --git a/node/ts/test/protocol/ProtocolTest.ts b/node/ts/test/protocol/ProtocolTest.ts index b92ab67b6..8634030e9 100644 --- a/node/ts/test/protocol/ProtocolTest.ts +++ b/node/ts/test/protocol/ProtocolTest.ts @@ -756,6 +756,7 @@ for (const testCase of sessionVersionTestCases) { const bDPlaintext = await SignalClient.signalDecryptPreKey( aCiphertextR, aAddress, + bAddress, bobStores.session, bobStores.identity, bobStores.prekey, @@ -852,6 +853,7 @@ for (const testCase of sessionVersionTestCases) { const bDPlaintext = await SignalClient.signalDecryptPreKey( aCiphertextR, aAddress, + bAddress, bobStores.session, bobStores.identity, bobStores.prekey, @@ -864,6 +866,7 @@ for (const testCase of sessionVersionTestCases) { await SignalClient.signalDecryptPreKey( aCiphertextR, aAddress, + bAddress, bobStores.session, bobStores.identity, bobStores.prekey, @@ -1020,6 +1023,7 @@ for (const testCase of sessionVersionTestCases) { void (await SignalClient.signalDecryptPreKey( aCiphertextR, aAddress, + bAddress, bobStores.session, bobStores.identity, bobStores.prekey, @@ -1031,6 +1035,7 @@ for (const testCase of sessionVersionTestCases) { SignalClient.signalDecryptPreKey( aCiphertextR, mAddress, + bAddress, bobStores.session, bobStores.identity, bobStores.prekey, diff --git a/rust/bridge/shared/src/protocol.rs b/rust/bridge/shared/src/protocol.rs index c15ca1436..34cee26eb 100644 --- a/rust/bridge/shared/src/protocol.rs +++ b/rust/bridge/shared/src/protocol.rs @@ -361,6 +361,7 @@ fn SignalMessage_New( SignalMessage::new( message_version, mac_key, + None, *sender_ratchet_key, counter, previous_counter, @@ -1068,6 +1069,7 @@ async fn SessionCipher_DecryptSignalMessage( async fn SessionCipher_DecryptPreKeySignalMessage( message: &PreKeySignalMessage, protocol_address: &ProtocolAddress, + local_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_key_store: &mut dyn IdentityKeyStore, prekey_store: &mut dyn PreKeyStore, @@ -1078,6 +1080,7 @@ async fn SessionCipher_DecryptPreKeySignalMessage( message_decrypt_prekey( message, protocol_address, + local_address, session_store, identity_key_store, prekey_store, diff --git a/rust/core/src/version.rs b/rust/core/src/version.rs index 00461a120..e93e1e6f9 100644 --- a/rust/core/src/version.rs +++ b/rust/core/src/version.rs @@ -5,4 +5,4 @@ // The value of this constant is updated by the script // and should not be manually modified -pub const VERSION: &str = "0.90.1"; +pub const VERSION: &str = "0.91.0"; diff --git a/rust/protocol/benches/session.rs b/rust/protocol/benches/session.rs index 2de8c151f..15163e22d 100644 --- a/rust/protocol/benches/session.rs +++ b/rust/protocol/benches/session.rs @@ -43,16 +43,26 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr c.bench_function("decrypting the first message on a chain", |b| { b.iter(|| { let mut bob_store = bob_store.clone(); - support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync") - .expect("success"); + support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync") + .expect("success"); }) }); - let _ = support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync")?; + let _ = support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync")?; let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") .now_or_never() @@ -73,10 +83,15 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr c.bench_function("decrypting on an existing chain", |b| { b.iter(|| { let mut bob_store = bob_store.clone(); - support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync") - .expect("success"); + support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync") + .expect("success"); }) }); @@ -211,17 +226,27 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr |b| { b.iter(|| { let mut bob_store = bob_store.clone(); - support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync") - .expect("success") + support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync") + .expect("success") }) }, ); - let _ = support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync")?; + let _ = support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync")?; // ...and prepare another message to benchmark decrypting. let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") .now_or_never() @@ -237,10 +262,15 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr |b| { b.iter(|| { let mut bob_store = bob_store.clone(); - support::decrypt(&mut bob_store, &alice_address, &message_to_decrypt) - .now_or_never() - .expect("sync") - .expect("success"); + support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &message_to_decrypt, + ) + .now_or_never() + .expect("sync") + .expect("success"); }) }, ); @@ -253,10 +283,15 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr |b| { b.iter(|| { let mut bob_store = bob_store.clone(); - support::decrypt(&mut bob_store, &alice_address, &original_message_to_decrypt) - .now_or_never() - .expect("sync") - .expect("success"); + support::decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &original_message_to_decrypt, + ) + .now_or_never() + .expect("sync") + .expect("success"); }) }, ); @@ -287,7 +322,7 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro .now_or_never() .expect("sync") .expect("success"); - let _ptext = support::decrypt(&mut bob_store, &alice_address, &ctext) + let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") .expect("success"); @@ -298,7 +333,7 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro .now_or_never() .expect("sync") .expect("success"); - let _ptext = support::decrypt(&mut bob_store, &alice_address, &ctext) + let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") .expect("success"); @@ -311,7 +346,7 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro .now_or_never() .expect("sync") .expect("success"); - let _ptext = support::decrypt(&mut bob_store, &alice_address, &ctext) + let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") .expect("success"); @@ -320,7 +355,7 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro .now_or_never() .expect("sync") .expect("success"); - let _ptext = support::decrypt(&mut alice_store, &bob_address, &ctext) + let _ptext = support::decrypt(&mut alice_store, &bob_address, &alice_address, &ctext) .now_or_never() .expect("sync") .expect("success"); diff --git a/rust/protocol/cross-version-testing/src/current.rs b/rust/protocol/cross-version-testing/src/current.rs index 35e68968a..88bc5f0a6 100644 --- a/rust/protocol/cross-version-testing/src/current.rs +++ b/rust/protocol/cross-version-testing/src/current.rs @@ -158,7 +158,13 @@ impl super::LibSignalProtocolStore for LibSignalProtocolCurrent { (encrypted.serialize().to_vec(), encrypted.message_type()) } - fn decrypt(&mut self, remote: &str, msg: &[u8], msg_type: CiphertextMessageType) -> Vec { + fn decrypt( + &mut self, + remote: &str, + local: &str, + msg: &[u8], + msg_type: CiphertextMessageType, + ) -> Vec { match msg_type { CiphertextMessageType::Whisper => message_decrypt_signal( &SignalMessage::try_from(msg).expect("valid"), @@ -173,6 +179,7 @@ impl super::LibSignalProtocolStore for LibSignalProtocolCurrent { CiphertextMessageType::PreKey => message_decrypt_prekey( &PreKeySignalMessage::try_from(msg).expect("valid"), &address(remote), + &address(local), &mut self.0.session_store, &mut self.0.identity_store, &mut self.0.pre_key_store, diff --git a/rust/protocol/cross-version-testing/src/lib.rs b/rust/protocol/cross-version-testing/src/lib.rs index af7489c8e..58012ae95 100644 --- a/rust/protocol/cross-version-testing/src/lib.rs +++ b/rust/protocol/cross-version-testing/src/lib.rs @@ -14,7 +14,13 @@ pub trait LibSignalProtocolStore { fn create_pre_key_bundle(&mut self) -> PreKeyBundle; fn process_pre_key_bundle(&mut self, remote: &str, pre_key_bundle: PreKeyBundle); fn encrypt(&mut self, remote: &str, msg: &[u8]) -> (Vec, CiphertextMessageType); - fn decrypt(&mut self, remote: &str, msg: &[u8], msg_type: CiphertextMessageType) -> Vec; + fn decrypt( + &mut self, + remote: &str, + local: &str, + msg: &[u8], + msg_type: CiphertextMessageType, + ) -> Vec; fn encrypt_sealed_sender_v1( &self, diff --git a/rust/protocol/cross-version-testing/src/v70.rs b/rust/protocol/cross-version-testing/src/v70.rs index a40fd776b..7400cc1a7 100644 --- a/rust/protocol/cross-version-testing/src/v70.rs +++ b/rust/protocol/cross-version-testing/src/v70.rs @@ -190,6 +190,7 @@ impl super::LibSignalProtocolStore for LibSignalProtocolV70 { fn decrypt( &mut self, remote: &str, + _local: &str, msg: &[u8], msg_type: super::CiphertextMessageType, ) -> Vec { diff --git a/rust/protocol/cross-version-testing/tests/session.rs b/rust/protocol/cross-version-testing/tests/session.rs index 5f55441f7..8415152c2 100644 --- a/rust/protocol/cross-version-testing/tests/session.rs +++ b/rust/protocol/cross-version-testing/tests/session.rs @@ -24,14 +24,20 @@ fn test_basic_prekey() { alice_store.encrypt(bob_name, original_message); assert_eq!(outgoing_message_type, CiphertextMessageType::PreKey); - let ptext = bob_store.decrypt(alice_name, &outgoing_message, outgoing_message_type); + let ptext = bob_store.decrypt( + alice_name, + bob_name, + &outgoing_message, + outgoing_message_type, + ); assert_eq!(&ptext, original_message); let bobs_response = "Who watches the watchers?".as_bytes(); let (bob_outgoing, bob_outgoing_type) = bob_store.encrypt(alice_name, bobs_response); assert_eq!(bob_outgoing_type, CiphertextMessageType::Whisper); - let alice_decrypts = alice_store.decrypt(bob_name, &bob_outgoing, bob_outgoing_type); + let alice_decrypts = + alice_store.decrypt(bob_name, alice_name, &bob_outgoing, bob_outgoing_type); assert_eq!(&alice_decrypts, bobs_response); run_interaction(alice_store, alice_name, bob_store, bob_name); @@ -49,7 +55,7 @@ fn run_interaction( let (alice_message, alice_message_type) = alice_store.encrypt(bob_name, alice_ptext); assert_eq!(alice_message_type, CiphertextMessageType::Whisper); assert_eq!( - &bob_store.decrypt(alice_name, &alice_message, alice_message_type), + &bob_store.decrypt(alice_name, bob_name, &alice_message, alice_message_type), alice_ptext ); @@ -58,7 +64,7 @@ fn run_interaction( let (bob_message, bob_message_type) = bob_store.encrypt(alice_name, bob_ptext); assert_eq!(bob_message_type, CiphertextMessageType::Whisper); assert_eq!( - &alice_store.decrypt(bob_name, &bob_message, bob_message_type), + &alice_store.decrypt(bob_name, alice_name, &bob_message, bob_message_type), bob_ptext ); @@ -68,7 +74,7 @@ fn run_interaction( alice_store.encrypt(bob_name, alice_ptext.as_bytes()); assert_eq!(alice_message_type, CiphertextMessageType::Whisper); assert_eq!( - &bob_store.decrypt(alice_name, &alice_message, alice_message_type), + &bob_store.decrypt(alice_name, bob_name, &alice_message, alice_message_type), alice_ptext.as_bytes() ); } @@ -78,7 +84,7 @@ fn run_interaction( let (bob_message, bob_message_type) = bob_store.encrypt(alice_name, bob_ptext.as_bytes()); assert_eq!(bob_message_type, CiphertextMessageType::Whisper); assert_eq!( - &alice_store.decrypt(bob_name, &bob_message, bob_message_type), + &alice_store.decrypt(bob_name, alice_name, &bob_message, bob_message_type), bob_ptext.as_bytes() ); } @@ -95,7 +101,12 @@ fn run_interaction( let alice_ptext = format!("A->B post-OOO message {}", i); let (alice_message, _) = alice_store.encrypt(bob_name, alice_ptext.as_bytes()); assert_eq!( - &bob_store.decrypt(alice_name, &alice_message, CiphertextMessageType::Whisper), + &bob_store.decrypt( + alice_name, + bob_name, + &alice_message, + CiphertextMessageType::Whisper + ), alice_ptext.as_bytes() ); } @@ -104,14 +115,19 @@ fn run_interaction( let bob_ptext = format!("B->A message post-OOO {}", i); let (bob_message, _) = bob_store.encrypt(alice_name, bob_ptext.as_bytes()); assert_eq!( - &alice_store.decrypt(bob_name, &bob_message, CiphertextMessageType::Whisper), + &alice_store.decrypt( + bob_name, + alice_name, + &bob_message, + CiphertextMessageType::Whisper + ), bob_ptext.as_bytes() ); } for (ptext, ctext) in alice_ooo_messages { assert_eq!( - &bob_store.decrypt(alice_name, &ctext, CiphertextMessageType::Whisper), + &bob_store.decrypt(alice_name, bob_name, &ctext, CiphertextMessageType::Whisper), ptext.as_bytes() ); } diff --git a/rust/protocol/fuzz/fuzz_targets/interaction.rs b/rust/protocol/fuzz/fuzz_targets/interaction.rs index b249b2000..fa7e43492 100644 --- a/rust/protocol/fuzz/fuzz_targets/interaction.rs +++ b/rust/protocol/fuzz/fuzz_targets/interaction.rs @@ -205,6 +205,7 @@ impl Participant { let decrypted = message_decrypt( &incoming_message, their_address, + &self.address, &mut self.store.session_store, &mut self.store.identity_store, &mut self.store.pre_key_store, diff --git a/rust/protocol/src/proto/wire.proto b/rust/protocol/src/proto/wire.proto index 822f6a9e2..4e2647f0e 100644 --- a/rust/protocol/src/proto/wire.proto +++ b/rust/protocol/src/proto/wire.proto @@ -8,11 +8,12 @@ syntax = "proto2"; package signal.proto.wire; message SignalMessage { - optional bytes ratchet_key = 1; - optional uint32 counter = 2; - optional uint32 previous_counter = 3; - optional bytes ciphertext = 4; - optional bytes pq_ratchet = 5; + optional bytes ratchet_key = 1; + optional uint32 counter = 2; + optional uint32 previous_counter = 3; + optional bytes ciphertext = 4; + optional bytes pq_ratchet = 5; + optional bytes recipient_address = 6; } message PreKeySignalMessage { diff --git a/rust/protocol/src/protocol.rs b/rust/protocol/src/protocol.rs index dfb64e9d4..67c756bac 100644 --- a/rust/protocol/src/protocol.rs +++ b/rust/protocol/src/protocol.rs @@ -12,7 +12,8 @@ use uuid::Uuid; use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId}; use crate::{ - IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError, Timestamp, kem, proto, + IdentityKey, PrivateKey, ProtocolAddress, PublicKey, Result, ServiceId, SignalProtocolError, + Timestamp, kem, proto, }; pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4; @@ -67,6 +68,7 @@ pub struct SignalMessage { previous_counter: u32, ciphertext: Box<[u8]>, pq_ratchet: spqr::SerializedState, + recipient_address: Option>, serialized: Box<[u8]>, } @@ -77,6 +79,7 @@ impl SignalMessage { pub fn new( message_version: u8, mac_key: &[u8], + recipient_address: Option<&ProtocolAddress>, sender_ratchet_key: PublicKey, counter: u32, previous_counter: u32, @@ -95,6 +98,7 @@ impl SignalMessage { } else { Some(pq_ratchet.to_vec()) }, + recipient_address: recipient_address.and_then(Self::serialize_recipient_address), }; let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH); serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION); @@ -116,6 +120,7 @@ impl SignalMessage { previous_counter, ciphertext: ciphertext.into(), pq_ratchet: pq_ratchet.to_vec(), + recipient_address: message.recipient_address.map(Into::into), serialized, }) } @@ -170,8 +175,41 @@ impl SignalMessage { hex::encode(their_mac), hex::encode(our_mac) ); + return Ok(false); + } + + Ok(true) + } + + pub fn verify_mac_with_recipient_address( + &self, + recipient_address: &ProtocolAddress, + sender_identity_key: &IdentityKey, + receiver_identity_key: &IdentityKey, + mac_key: &[u8], + ) -> Result { + if !self.verify_mac(sender_identity_key, receiver_identity_key, mac_key)? { + return Ok(false); + } + + // If the sender didn't include a recipient address, accept the message for + // backward compatibility with older clients. + let Some(encoded_recipient_address) = &self.recipient_address else { + return Ok(true); + }; + + // Only match valid Service IDs. + let Some(expected) = Self::serialize_recipient_address(recipient_address) else { + log::warn!("Local address not a valid Service ID {}", recipient_address); + return Ok(false); + }; + + if bool::from(expected.ct_eq(encoded_recipient_address.as_ref())) { + Ok(true) + } else { + log::warn!("Recipient address mismatch for {}", recipient_address); + Ok(false) } - Ok(result) } fn compute_mac( @@ -196,6 +234,15 @@ impl SignalMessage { .expect("enough bytes"); Ok(result) } + + /// Serializes the recipient address to Service-Id-Fixed-Width-Binary (17 bytes) + device ID + /// (1 byte). Returns `None` if the address name is not a valid ServiceId. + fn serialize_recipient_address(recipient_address: &ProtocolAddress) -> Option> { + let service_id = ServiceId::parse_from_service_id_string(recipient_address.name())?; + let mut bytes = service_id.service_id_fixed_width_binary().to_vec(); + bytes.push(recipient_address.device_id().into()); + Some(bytes) + } } impl AsRef<[u8]> for SignalMessage { @@ -247,6 +294,7 @@ impl TryFrom<&[u8]> for SignalMessage { previous_counter, ciphertext, pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]), + recipient_address: proto_structure.recipient_address.map(Into::into), serialized: Box::from(value), }) } @@ -918,10 +966,10 @@ pub fn extract_decryption_error_message_from_serialized_content( #[cfg(test)] mod tests { use rand::rngs::OsRng; - use rand::{CryptoRng, Rng, TryRngCore as _}; + use rand::{CryptoRng, Rng, RngCore, TryRngCore as _}; use super::*; - use crate::KeyPair; + use crate::{DeviceId, KeyPair}; fn create_signal_message(csprng: &mut T) -> Result where @@ -938,10 +986,13 @@ mod tests { let sender_ratchet_key_pair = KeyPair::generate(csprng); let sender_identity_key_pair = KeyPair::generate(csprng); let receiver_identity_key_pair = KeyPair::generate(csprng); + let recipient_address = + ProtocolAddress::new("recipient".to_owned(), DeviceId::new(1).unwrap()); SignalMessage::new( 4, &mac_key, + Some(&recipient_address), sender_ratchet_key_pair.public_key, 42, 41, @@ -958,6 +1009,7 @@ mod tests { assert_eq!(m1.counter, m2.counter); assert_eq!(m1.previous_counter, m2.previous_counter); assert_eq!(m1.ciphertext, m2.ciphertext); + assert_eq!(m1.recipient_address, m2.recipient_address); assert_eq!(m1.serialized, m2.serialized); } @@ -1025,6 +1077,109 @@ mod tests { Ok(()) } + #[test] + fn test_signal_message_verify_mac_accepts_legacy_message_without_recipient_address() + -> Result<()> { + let mut csprng = OsRng.unwrap_err(); + let mut mac_key = [0u8; 32]; + csprng.fill_bytes(&mut mac_key); + + let mut ciphertext = [0u8; 20]; + csprng.fill_bytes(&mut ciphertext); + + let sender_ratchet_key_pair = KeyPair::generate(&mut csprng); + let sender_identity_key_pair = KeyPair::generate(&mut csprng); + let receiver_identity_key_pair = KeyPair::generate(&mut csprng); + let recipient_address = ProtocolAddress::new( + "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + DeviceId::new(1).unwrap(), + ); + + let message = SignalMessage::new( + 4, + &mac_key, + Some(&recipient_address), + sender_ratchet_key_pair.public_key, + 42, + 41, + &ciphertext, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + b"", + )?; + + let mut proto_structure = proto::wire::SignalMessage::decode( + &message.serialized()[1..message.serialized().len() - SignalMessage::MAC_LENGTH], + ) + .expect("valid protobuf"); + proto_structure.recipient_address = None; + + let mut serialized = + vec![((message.message_version() & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION]; + proto_structure.encode(&mut serialized).expect("encodes"); + let mac = SignalMessage::compute_mac( + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + &mac_key, + &serialized, + )?; + serialized.extend_from_slice(&mac); + + let legacy_message = SignalMessage::try_from(serialized.as_slice())?; + assert!(legacy_message.verify_mac_with_recipient_address( + &recipient_address, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + &mac_key, + )?); + + Ok(()) + } + + #[test] + fn test_signal_message_verify_mac_rejects_wrong_recipient_address() -> Result<()> { + let mut csprng = OsRng.unwrap_err(); + let mut mac_key = [0u8; 32]; + csprng.fill_bytes(&mut mac_key); + + let mut ciphertext = [0u8; 20]; + csprng.fill_bytes(&mut ciphertext); + + let sender_ratchet_key_pair = KeyPair::generate(&mut csprng); + let sender_identity_key_pair = KeyPair::generate(&mut csprng); + let receiver_identity_key_pair = KeyPair::generate(&mut csprng); + let recipient_address = ProtocolAddress::new( + "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + DeviceId::new(1).unwrap(), + ); + let wrong_recipient_address = ProtocolAddress::new( + "a5e2f8d1-4b3c-4e7a-8f9d-1c2b3d4e5f6a".to_owned(), + DeviceId::new(1).unwrap(), + ); + + let message = SignalMessage::new( + 4, + &mac_key, + Some(&recipient_address), + sender_ratchet_key_pair.public_key, + 42, + 41, + &ciphertext, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + b"", + )?; + + assert!(!message.verify_mac_with_recipient_address( + &wrong_recipient_address, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + &mac_key, + )?); + + Ok(()) + } + #[test] fn test_sender_key_message_serialize_deserialize() -> Result<()> { let mut csprng = OsRng.unwrap_err(); diff --git a/rust/protocol/src/sealed_sender.rs b/rust/protocol/src/sealed_sender.rs index 62990b96f..a00d128a9 100644 --- a/rust/protocol/src/sealed_sender.rs +++ b/rust/protocol/src/sealed_sender.rs @@ -2027,6 +2027,7 @@ pub async fn sealed_sender_decrypt( usmc.sender()?.sender_uuid()?.to_string(), usmc.sender()?.sender_device_id()?, ); + let local_address = ProtocolAddress::new(local_uuid, local_device_id); let message = match usmc.msg_type()? { CiphertextMessageType::Whisper => { @@ -2045,6 +2046,7 @@ pub async fn sealed_sender_decrypt( session_cipher::message_decrypt_prekey( &ctext, &remote_address, + &local_address, session_store, identity_store, pre_key_store, diff --git a/rust/protocol/src/session_cipher.rs b/rust/protocol/src/session_cipher.rs index 7f6ce60e4..9347aeeb9 100644 --- a/rust/protocol/src/session_cipher.rs +++ b/rust/protocol/src/session_cipher.rs @@ -92,6 +92,7 @@ pub async fn message_encrypt( let message = SignalMessage::new( session_version, message_keys.mac_key(), + Some(remote_address), sender_ephemeral, chain_key.index(), previous_counter, @@ -120,6 +121,7 @@ pub async fn message_encrypt( CiphertextMessage::SignalMessage(SignalMessage::new( session_version, message_keys.mac_key(), + None, sender_ephemeral, chain_key.index(), previous_counter, @@ -162,6 +164,7 @@ pub async fn message_encrypt( pub async fn message_decrypt( ciphertext: &CiphertextMessage, remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, @@ -171,12 +174,14 @@ pub async fn message_decrypt( ) -> Result> { match ciphertext { CiphertextMessage::SignalMessage(m) => { + let _ = local_address; message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await } CiphertextMessage::PreKeySignalMessage(m) => { message_decrypt_prekey( m, remote_address, + local_address, session_store, identity_store, pre_key_store, @@ -197,6 +202,7 @@ pub async fn message_decrypt( pub async fn message_decrypt_prekey( ciphertext: &PreKeySignalMessage, remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, pre_key_store: &mut dyn PreKeyStore, @@ -241,6 +247,7 @@ pub async fn message_decrypt_prekey( let ptext = decrypt_message_with_record( remote_address, + Some(local_address), &mut session_record, ciphertext.message(), CiphertextMessageType::PreKey, @@ -291,6 +298,7 @@ pub async fn message_decrypt_signal( let ptext = decrypt_message_with_record( remote_address, + None, &mut session_record, ciphertext, CiphertextMessageType::Whisper, @@ -423,6 +431,7 @@ fn create_decryption_failure_log( fn decrypt_message_with_record( remote_address: &ProtocolAddress, + local_address: Option<&ProtocolAddress>, record: &mut SessionRecord, ciphertext: &SignalMessage, original_message_type: CiphertextMessageType, @@ -459,6 +468,7 @@ fn decrypt_message_with_record( ciphertext, original_message_type, remote_address, + local_address, csprng, ); @@ -522,6 +532,7 @@ fn decrypt_message_with_record( ciphertext, original_message_type, remote_address, + local_address, csprng, ); @@ -602,6 +613,7 @@ fn decrypt_message_with_state( ciphertext: &SignalMessage, original_message_type: CiphertextMessageType, remote_address: &ProtocolAddress, + local_address: Option<&ProtocolAddress>, csprng: &mut R, ) -> Result> { // Check for a completely empty or invalid session state before we do anything else. @@ -654,11 +666,19 @@ fn decrypt_message_with_state( "cannot decrypt without remote identity key", ))?; - let mac_valid = ciphertext.verify_mac( - &their_identity_key, - &state.local_identity_key()?, - message_keys.mac_key(), - )?; + let mac_valid = match local_address { + Some(local_address) => ciphertext.verify_mac_with_recipient_address( + local_address, + &their_identity_key, + &state.local_identity_key()?, + message_keys.mac_key(), + )?, + None => ciphertext.verify_mac( + &their_identity_key, + &state.local_identity_key()?, + message_keys.mac_key(), + )?, + }; if !mac_valid { return Err(SignalProtocolError::InvalidMessage( diff --git a/rust/protocol/test-support/src/lib.rs b/rust/protocol/test-support/src/lib.rs index bcb3c2c97..c4093b6b8 100644 --- a/rust/protocol/test-support/src/lib.rs +++ b/rust/protocol/test-support/src/lib.rs @@ -271,6 +271,7 @@ impl Participant { match message_decrypt( &incoming_message, &them.address, + &self.address, &mut self.state.store.session_store, &mut self.state.store.identity_store, &mut self.state.store.pre_key_store, diff --git a/rust/protocol/tests/sealed_sender.rs b/rust/protocol/tests/sealed_sender.rs index 320198689..d8297d782 100644 --- a/rust/protocol/tests/sealed_sender.rs +++ b/rust/protocol/tests/sealed_sender.rs @@ -936,6 +936,7 @@ fn test_decryption_error_in_sealed_sender() -> Result<(), SignalProtocolError> { message_decrypt( &bob_first_message, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, &mut alice_store.pre_key_store, diff --git a/rust/protocol/tests/session.rs b/rust/protocol/tests/session.rs index d3c48de66..7ae25bfdc 100644 --- a/rust/protocol/tests/session.rs +++ b/rust/protocol/tests/session.rs @@ -91,6 +91,7 @@ fn test_basic_prekey() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await?; @@ -130,7 +131,7 @@ fn test_basic_prekey() -> TestResult { assert_eq!(bob_outgoing.message_type(), CiphertextMessageType::Whisper); - let alice_decrypts = decrypt(alice_store, &bob_address, &bob_outgoing).await?; + let alice_decrypts = decrypt(alice_store, &bob_address, &alice_address, &bob_outgoing).await?; assert_eq!( String::from_utf8(alice_decrypts).expect("valid utf8"), @@ -172,7 +173,7 @@ fn test_basic_prekey() -> TestResult { encrypt(&mut alter_alice_store, &bob_address, original_message).await?; assert!(matches!( - decrypt(&mut bob_store_builder.store, &alice_address, &outgoing_message) + decrypt(&mut bob_store_builder.store, &alice_address, &bob_address, &outgoing_message) .await .unwrap_err(), SignalProtocolError::UntrustedIdentity(a) if a == alice_address @@ -195,6 +196,7 @@ fn test_basic_prekey() -> TestResult { let decrypted = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &outgoing_message, ) .await?; @@ -285,9 +287,14 @@ fn test_chain_jump_over_limit() -> TestResult { let too_far = encrypt(alice_store, &bob_address, "Now you have gone too far").await?; assert!( - decrypt(&mut bob_store_builder.store, &alice_address, &too_far,) - .await - .is_err() + decrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + &too_far, + ) + .await + .is_err() ); Ok(()) } @@ -345,7 +352,13 @@ fn test_chain_jump_over_limit_with_self() -> TestResult { let too_far = encrypt(a1_store, &a2_address, "This is the song that never ends").await?; - let ptext = decrypt(&mut a2_store_builder.store, &a1_address, &too_far).await?; + let ptext = decrypt( + &mut a2_store_builder.store, + &a1_address, + &a2_address, + &too_far, + ) + .await?; assert_eq!( String::from_utf8(ptext).unwrap(), "This is the song that never ends" @@ -490,6 +503,7 @@ fn test_repeat_bundle_message() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await?; @@ -505,7 +519,8 @@ fn test_repeat_bundle_message() -> TestResult { ) .await?; assert_eq!(bob_outgoing.message_type(), CiphertextMessageType::Whisper); - let alice_decrypts = decrypt(alice_store, &bob_address, &bob_outgoing).await?; + let alice_decrypts = + decrypt(alice_store, &bob_address, &alice_address, &bob_outgoing).await?; assert_eq!( String::from_utf8(alice_decrypts).expect("valid utf8"), original_message @@ -520,6 +535,7 @@ fn test_repeat_bundle_message() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message2, ) .await?; @@ -534,7 +550,8 @@ fn test_repeat_bundle_message() -> TestResult { original_message, ) .await?; - let alice_decrypts = decrypt(alice_store, &bob_address, &bob_outgoing).await?; + let alice_decrypts = + decrypt(alice_store, &bob_address, &alice_address, &bob_outgoing).await?; assert_eq!( String::from_utf8(alice_decrypts).expect("valid utf8"), original_message @@ -618,7 +635,7 @@ fn test_bad_message_bundle() -> TestResult { ); assert!( - decrypt(bob_store, &alice_address, &incoming_message,) + decrypt(bob_store, &alice_address, &bob_address, &incoming_message,) .await .is_err() ); @@ -628,7 +645,7 @@ fn test_bad_message_bundle() -> TestResult { PreKeySignalMessage::try_from(outgoing_message.as_slice())?, ); - let ptext = decrypt(bob_store, &alice_address, &incoming_message).await?; + let ptext = decrypt(bob_store, &alice_address, &bob_address, &incoming_message).await?; assert_eq!( String::from_utf8(ptext).expect("valid utf8"), @@ -708,6 +725,7 @@ fn test_optional_one_time_prekey() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await?; @@ -768,8 +786,16 @@ fn test_message_key_limits() -> TestResult { } assert_eq!( - String::from_utf8(decrypt(&mut bob_store, &alice_address, &inflight[1000],).await?) - .expect("valid utf8"), + String::from_utf8( + decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &inflight[1000], + ) + .await? + ) + .expect("valid utf8"), "It's over 1000" ); assert_eq!( @@ -777,6 +803,7 @@ fn test_message_key_limits() -> TestResult { decrypt( &mut bob_store, &alice_address, + &bob_address, &inflight[TOO_MANY_MESSAGES - 1], ) .await? @@ -785,7 +812,7 @@ fn test_message_key_limits() -> TestResult { format!("It's over {}", TOO_MANY_MESSAGES - 1) ); - let err = decrypt(&mut bob_store, &alice_address, &inflight[5]) + let err = decrypt(&mut bob_store, &alice_address, &bob_address, &inflight[5]) .await .unwrap_err(); assert!(matches!( @@ -877,6 +904,7 @@ fn test_basic_simultaneous_initiate() -> TestResult { let alice_plaintext = decrypt( alice_store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_alice.serialize(), )?), @@ -890,6 +918,7 @@ fn test_basic_simultaneous_initiate() -> TestResult { let bob_plaintext = decrypt( bob_store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_bob.serialize(), )?), @@ -923,6 +952,7 @@ fn test_basic_simultaneous_initiate() -> TestResult { let response_plaintext = decrypt( bob_store, &alice_address, + &bob_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( alice_response.serialize(), )?), @@ -944,6 +974,7 @@ fn test_basic_simultaneous_initiate() -> TestResult { let response_plaintext = decrypt( alice_store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1043,6 +1074,7 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { let bob_plaintext = decrypt( bob_store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1069,6 +1101,7 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { let response_plaintext = decrypt( bob_store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( alice_response.serialize(), )?), @@ -1090,6 +1123,7 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { let response_plaintext = decrypt( alice_store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1189,6 +1223,7 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { let alice_plaintext = decrypt( alice_store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_alice.serialize(), )?), @@ -1202,6 +1237,7 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { let bob_plaintext = decrypt( bob_store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1243,6 +1279,7 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { let response_plaintext = decrypt( alice_store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1348,6 +1385,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let alice_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_alice.serialize(), )?), @@ -1361,6 +1399,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let bob_plaintext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1419,6 +1458,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let alice_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( message_for_alice.serialize(), )?), @@ -1432,6 +1472,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let bob_plaintext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1492,6 +1533,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let response_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1624,6 +1666,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let alice_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_alice.serialize(), )?), @@ -1637,6 +1680,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let bob_plaintext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1695,6 +1739,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let alice_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( message_for_alice.serialize(), )?), @@ -1708,6 +1753,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let bob_plaintext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( message_for_bob.serialize(), )?), @@ -1768,6 +1814,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let response_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1791,6 +1838,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let blast_from_the_past = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::try_from( lost_message_for_bob.serialize(), )?), @@ -1819,6 +1867,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let response_plaintext = decrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, &CiphertextMessage::SignalMessage(SignalMessage::try_from( bob_response.serialize(), )?), @@ -1901,6 +1950,7 @@ fn test_zero_is_a_valid_prekey_id() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await?; @@ -2106,6 +2156,7 @@ fn prekey_message_failed_decryption_does_not_update_stores() -> TestResult { decrypt( &mut alice_store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(pre_key_message), ) .await, @@ -2176,9 +2227,14 @@ fn prekey_message_failed_decryption_does_not_update_stores_even_when_previously_ let bob_ciphertext = encrypt(&mut bob_store, &alice_address, "from Bob") .await .expect("valid"); - _ = decrypt(&mut alice_store, &bob_address, &bob_ciphertext) - .await - .expect("valid"); + _ = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_ciphertext, + ) + .await + .expect("valid"); // Alice archives the session because she feels like it. let mut alice_session_with_bob = alice_store @@ -2245,6 +2301,7 @@ fn prekey_message_failed_decryption_does_not_update_stores_even_when_previously_ decrypt( &mut alice_store, &bob_address, + &alice_address, &CiphertextMessage::PreKeySignalMessage(pre_key_message), ) .await, @@ -2323,9 +2380,14 @@ fn prekey_message_to_archived_session() -> TestResult { assert_eq!(bob_ciphertext.message_type(), CiphertextMessageType::PreKey); // Alice receives the message. - let received_message = decrypt(&mut alice_store, &bob_address, &bob_ciphertext) - .await - .expect("valid"); + let received_message = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_ciphertext, + ) + .await + .expect("valid"); assert_eq!(received_message, b"from Bob"); // Alice decides to archive the session and then send a message to Bob on a new session. @@ -2357,9 +2419,14 @@ fn prekey_message_to_archived_session() -> TestResult { bob_ciphertext_2.message_type(), CiphertextMessageType::PreKey ); - let received_message_2 = decrypt(&mut alice_store, &bob_address, &bob_ciphertext_2) - .await - .expect("valid"); + let received_message_2 = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_ciphertext_2, + ) + .await + .expect("valid"); assert_eq!(received_message_2, b"from Bob 2"); // This should promote Bob's session back to the front of Alice's session state. @@ -2410,7 +2477,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec let alice_plaintext = "This is Alice's message"; let alice_ciphertext = encrypt(&mut alice_store, &bob_address, alice_plaintext).await?; - let bob_decrypted = decrypt(&mut bob_store, &alice_address, &alice_ciphertext).await?; + let bob_decrypted = decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &alice_ciphertext, + ) + .await?; assert_eq!( String::from_utf8(bob_decrypted).expect("valid utf8"), alice_plaintext @@ -2419,7 +2492,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec let bob_plaintext = "This is Bob's reply"; let bob_ciphertext = encrypt(&mut bob_store, &alice_address, bob_plaintext).await?; - let alice_decrypted = decrypt(&mut alice_store, &bob_address, &bob_ciphertext).await?; + let alice_decrypted = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_ciphertext, + ) + .await?; assert_eq!( String::from_utf8(alice_decrypted).expect("valid utf8"), bob_plaintext @@ -2441,7 +2520,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec alice_messages.shuffle(&mut rng); for i in 0..ALICE_MESSAGE_COUNT / 2 { - let ptext = decrypt(&mut bob_store, &alice_address, &alice_messages[i].1).await?; + let ptext = decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &alice_messages[i].1, + ) + .await?; assert_eq!( String::from_utf8(ptext).expect("valid utf8"), alice_messages[i].0 @@ -2459,7 +2544,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec bob_messages.shuffle(&mut rng); for i in 0..BOB_MESSAGE_COUNT / 2 { - let ptext = decrypt(&mut alice_store, &bob_address, &bob_messages[i].1).await?; + let ptext = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_messages[i].1, + ) + .await?; assert_eq!( String::from_utf8(ptext).expect("valid utf8"), bob_messages[i].0 @@ -2467,7 +2558,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec } for i in ALICE_MESSAGE_COUNT / 2..ALICE_MESSAGE_COUNT { - let ptext = decrypt(&mut bob_store, &alice_address, &alice_messages[i].1).await?; + let ptext = decrypt( + &mut bob_store, + &alice_address, + &bob_address, + &alice_messages[i].1, + ) + .await?; assert_eq!( String::from_utf8(ptext).expect("valid utf8"), alice_messages[i].0 @@ -2475,7 +2572,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec } for i in BOB_MESSAGE_COUNT / 2..BOB_MESSAGE_COUNT { - let ptext = decrypt(&mut alice_store, &bob_address, &bob_messages[i].1).await?; + let ptext = decrypt( + &mut alice_store, + &bob_address, + &alice_address, + &bob_messages[i].1, + ) + .await?; assert_eq!( String::from_utf8(ptext).expect("valid utf8"), bob_messages[i].0 @@ -2499,7 +2602,7 @@ async fn run_interaction( let alice_message = encrypt(alice_store, bob_address, alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(bob_store, alice_address, &alice_message).await?) + String::from_utf8(decrypt(bob_store, alice_address, bob_address, &alice_message).await?) .expect("valid utf8"), alice_ptext ); @@ -2509,7 +2612,7 @@ async fn run_interaction( let bob_message = encrypt(bob_store, alice_address, bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(alice_store, bob_address, &bob_message).await?) + String::from_utf8(decrypt(alice_store, bob_address, alice_address, &bob_message).await?) .expect("valid utf8"), bob_ptext ); @@ -2519,8 +2622,10 @@ async fn run_interaction( let alice_message = encrypt(alice_store, bob_address, &alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(bob_store, alice_address, &alice_message).await?) - .expect("valid utf8"), + String::from_utf8( + decrypt(bob_store, alice_address, bob_address, &alice_message).await? + ) + .expect("valid utf8"), alice_ptext ); } @@ -2530,8 +2635,10 @@ async fn run_interaction( let bob_message = encrypt(bob_store, alice_address, &bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(alice_store, bob_address, &bob_message).await?) - .expect("valid utf8"), + String::from_utf8( + decrypt(alice_store, bob_address, alice_address, &bob_message).await? + ) + .expect("valid utf8"), bob_ptext ); } @@ -2549,8 +2656,10 @@ async fn run_interaction( let alice_message = encrypt(alice_store, bob_address, &alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(bob_store, alice_address, &alice_message).await?) - .expect("valid utf8"), + String::from_utf8( + decrypt(bob_store, alice_address, bob_address, &alice_message).await? + ) + .expect("valid utf8"), alice_ptext ); } @@ -2560,15 +2669,17 @@ async fn run_interaction( let bob_message = encrypt(bob_store, alice_address, &bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( - String::from_utf8(decrypt(alice_store, bob_address, &bob_message).await?) - .expect("valid utf8"), + String::from_utf8( + decrypt(alice_store, bob_address, alice_address, &bob_message).await? + ) + .expect("valid utf8"), bob_ptext ); } for (ptext, ctext) in alice_ooo_messages { assert_eq!( - String::from_utf8(decrypt(bob_store, alice_address, &ctext).await?) + String::from_utf8(decrypt(bob_store, alice_address, bob_address, &ctext).await?) .expect("valid utf8"), ptext ); @@ -2647,6 +2758,7 @@ fn test_signedprekey_not_saved() -> TestResult { let ptext = decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await?; @@ -2695,6 +2807,7 @@ fn test_signedprekey_not_saved() -> TestResult { decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &incoming_message, ) .await @@ -2896,14 +3009,14 @@ fn test_longer_sessions() -> TestResult { None => {} Some((_reordered, msg)) => { log::debug!("Process message to Alice"); - decrypt(alice_store, &bob_address, &msg).await?; + decrypt(alice_store, &bob_address, &alice_address, &msg).await?; } }, LongerSessionActions::BobRecv => match to_bob.pop_front() { None => {} Some((_reordered, msg)) => { log::debug!("Process message to Bob"); - decrypt(bob_store, &alice_address, &msg).await?; + decrypt(bob_store, &alice_address, &bob_address, &msg).await?; } }, LongerSessionActions::AliceDrop => { @@ -2987,8 +3100,8 @@ fn test_duplicate_message_error_returned() -> TestResult { .await?; let msg = encrypt(alice_store, &bob_address, "this_will_be_a_dup").await?; - decrypt(bob_store, &alice_address, &msg).await?; - let err = decrypt(bob_store, &alice_address, &msg) + decrypt(bob_store, &alice_address, &bob_address, &msg).await?; + let err = decrypt(bob_store, &alice_address, &bob_address, &msg) .await .expect_err("should be a duplicate"); assert!(matches!(err, SignalProtocolError::DuplicatedMessage(_, _))); @@ -3035,11 +3148,11 @@ fn test_pqr_state_and_message_contents_nonempty() -> TestResult { let msg = encrypt(alice_store, &bob_address, "msg1").await?; assert_matches!(&msg, CiphertextMessage::PreKeySignalMessage(m) if !m.message().pq_ratchet().is_empty()); - decrypt(bob_store, &alice_address, &msg).await?; + decrypt(bob_store, &alice_address, &bob_address, &msg).await?; let msg = encrypt(bob_store, &alice_address, "msg2").await?; assert_matches!(&msg, CiphertextMessage::SignalMessage(m) if !m.pq_ratchet().is_empty()); - decrypt(alice_store, &bob_address, &msg).await?; + decrypt(alice_store, &bob_address, &alice_address, &msg).await?; let msg = encrypt(alice_store, &bob_address, "msg3").await?; assert_matches!(&msg, CiphertextMessage::SignalMessage(m) if !m.pq_ratchet().is_empty()); @@ -3103,9 +3216,14 @@ fn x3dh_prekey_rejected_as_invalid_message_specifically() { .expect("valid"); let mut bob_one_off_store = bob_store_builder.store.clone(); - _ = support::decrypt(&mut bob_one_off_store, &alice_address, &pre_key_message) - .await - .expect("unmodified message is fine"); + _ = support::decrypt( + &mut bob_one_off_store, + &alice_address, + &bob_address, + &pre_key_message, + ) + .await + .expect("unmodified message is fine"); let original = assert_matches!(pre_key_message, CiphertextMessage::PreKeySignalMessage(m) => m); @@ -3124,6 +3242,7 @@ fn x3dh_prekey_rejected_as_invalid_message_specifically() { let err = support::decrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, &CiphertextMessage::PreKeySignalMessage(modified_message.clone()), ) .await @@ -3175,7 +3294,7 @@ fn x3dh_established_session_is_or_is_not_usable() { .expect("valid"); let bob_store = &mut bob_store_builder.store; - _ = support::decrypt(bob_store, &alice_address, &pre_key_message) + _ = support::decrypt(bob_store, &alice_address, &bob_address, &pre_key_message) .await .expect("unmodified message is fine"); @@ -3277,7 +3396,7 @@ fn prekey_message_sent_from_different_user_is_rejected() { .expect("valid"); let bob_store = &mut bob_store_builder.store; - _ = support::decrypt(bob_store, &alice_address, &pre_key_message) + _ = support::decrypt(bob_store, &alice_address, &bob_address, &pre_key_message) .await .expect("unmodified message is fine"); _ = bob_store @@ -3286,7 +3405,7 @@ fn prekey_message_sent_from_different_user_is_rejected() { .expect("can load sessions") .expect("session successfully created"); - let err = support::decrypt(bob_store, &mallory_address, &pre_key_message) + let err = support::decrypt(bob_store, &mallory_address, &bob_address, &pre_key_message) .await .expect_err("should be rejected"); assert_matches!( @@ -3306,6 +3425,65 @@ fn prekey_message_sent_from_different_user_is_rejected() { .expect("sync") } +#[test] +fn prekey_message_rejects_wrong_local_recipient_address() { + async { + let mut csprng = OsRng.unwrap_err(); + + let alice_address = ProtocolAddress::new( + "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + DeviceId::new(1).unwrap(), + ); + let bob_address = ProtocolAddress::new( + "796abedb-ca4e-4f18-8803-1fde5b921f9f".to_owned(), + DeviceId::new(1).unwrap(), + ); + let mallory_address = ProtocolAddress::new( + "e80f7bbe-5b94-471e-bd8c-2173654ea3d1".to_owned(), + DeviceId::new(1).unwrap(), + ); + + let mut bob_store_builder = TestStoreBuilder::new(); + bob_store_builder.add_pre_key(IdChoice::Next); + bob_store_builder.add_signed_pre_key(IdChoice::Next); + bob_store_builder.add_kyber_pre_key(IdChoice::Next); + + let bob_pre_key_bundle = + bob_store_builder.make_bundle_with_latest_keys(DeviceId::new(1).unwrap()); + + let mut alice_store = TestStoreBuilder::new().store; + process_prekey_bundle( + &bob_address, + &mut alice_store.session_store, + &mut alice_store.identity_store, + &bob_pre_key_bundle, + SystemTime::now(), + &mut csprng, + ) + .await + .expect("valid"); + + let pre_key_message = support::encrypt(&mut alice_store, &bob_address, "hi bob") + .await + .expect("valid"); + + let err = support::decrypt( + &mut bob_store_builder.store, + &alice_address, + &mallory_address, + &pre_key_message, + ) + .await + .expect_err("recipient binding should reject wrong local address"); + assert_matches!( + err, + SignalProtocolError::InvalidMessage(CiphertextMessageType::PreKey, "decryption failed") + ); + } + .now_or_never() + .expect("sync") +} + #[test] fn proptest_session_resets() { // This is the same test setup as fuzz/fuzz_targets/session_management.rs. If this test fails, diff --git a/rust/protocol/tests/support/mod.rs b/rust/protocol/tests/support/mod.rs index 5540eb96e..4bab798fc 100644 --- a/rust/protocol/tests/support/mod.rs +++ b/rust/protocol/tests/support/mod.rs @@ -48,12 +48,14 @@ pub async fn encrypt( pub async fn decrypt( store: &mut InMemSignalProtocolStore, remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, msg: &CiphertextMessage, ) -> Result, SignalProtocolError> { let mut csprng = OsRng.unwrap_err(); message_decrypt( msg, remote_address, + local_address, &mut store.session_store, &mut store.identity_store, &mut store.pre_key_store, diff --git a/swift/Sources/LibSignalClient/Protocol.swift b/swift/Sources/LibSignalClient/Protocol.swift index 1ea41a4d0..a513b3186 100644 --- a/swift/Sources/LibSignalClient/Protocol.swift +++ b/swift/Sources/LibSignalClient/Protocol.swift @@ -59,6 +59,7 @@ public func signalDecrypt( public func signalDecryptPreKey( message: PreKeySignalMessage, from address: ProtocolAddress, + localAddress: ProtocolAddress, sessionStore: SessionStore, identityStore: IdentityKeyStore, preKeyStore: PreKeyStore, @@ -66,7 +67,7 @@ public func signalDecryptPreKey( kyberPreKeyStore: KyberPreKeyStore, context: StoreContext ) throws -> Data { - return try withAllBorrowed(message, address) { messageHandle, addressHandle in + return try withAllBorrowed(message, address, localAddress) { messageHandle, addressHandle, localAddressHandle in try withSessionStore(sessionStore, context) { ffiSessionStore in try withIdentityKeyStore(identityStore, context) { ffiIdentityStore in try withPreKeyStore(preKeyStore, context) { ffiPreKeyStore in @@ -77,6 +78,7 @@ public func signalDecryptPreKey( $0, messageHandle.const(), addressHandle.const(), + localAddressHandle.const(), ffiSessionStore, ffiIdentityStore, ffiPreKeyStore, diff --git a/swift/Sources/SignalFfi/signal_ffi.h b/swift/Sources/SignalFfi/signal_ffi.h index e4d7ab6cc..cb39ce7c5 100644 --- a/swift/Sources/SignalFfi/signal_ffi.h +++ b/swift/Sources/SignalFfi/signal_ffi.h @@ -1903,7 +1903,7 @@ SignalFfiError *signal_create_call_link_credential_response_check_valid_contents SignalFfiError *signal_decrypt_message(SignalOwnedBuffer *out, SignalConstPointerSignalMessage message, SignalConstPointerProtocolAddress protocol_address, SignalConstPointerFfiSessionStoreStruct session_store, SignalConstPointerFfiIdentityKeyStoreStruct identity_key_store); -SignalFfiError *signal_decrypt_pre_key_message(SignalOwnedBuffer *out, SignalConstPointerPreKeySignalMessage message, SignalConstPointerProtocolAddress protocol_address, SignalConstPointerFfiSessionStoreStruct session_store, SignalConstPointerFfiIdentityKeyStoreStruct identity_key_store, SignalConstPointerFfiPreKeyStoreStruct prekey_store, SignalConstPointerFfiSignedPreKeyStoreStruct signed_prekey_store, SignalConstPointerFfiKyberPreKeyStoreStruct kyber_prekey_store); +SignalFfiError *signal_decrypt_pre_key_message(SignalOwnedBuffer *out, SignalConstPointerPreKeySignalMessage message, SignalConstPointerProtocolAddress protocol_address, SignalConstPointerProtocolAddress local_address, SignalConstPointerFfiSessionStoreStruct session_store, SignalConstPointerFfiIdentityKeyStoreStruct identity_key_store, SignalConstPointerFfiPreKeyStoreStruct prekey_store, SignalConstPointerFfiSignedPreKeyStoreStruct signed_prekey_store, SignalConstPointerFfiKyberPreKeyStoreStruct kyber_prekey_store); SignalFfiError *signal_decryption_error_message_clone(SignalMutPointerDecryptionErrorMessage *new_obj, SignalConstPointerDecryptionErrorMessage obj); diff --git a/swift/Tests/LibSignalClientTests/SessionTests.swift b/swift/Tests/LibSignalClientTests/SessionTests.swift index f3a7c69c7..0f192738e 100644 --- a/swift/Tests/LibSignalClientTests/SessionTests.swift +++ b/swift/Tests/LibSignalClientTests/SessionTests.swift @@ -43,6 +43,7 @@ class SessionTests: TestCaseBase { let ptext_b = try! signalDecryptPreKey( message: ctext_b, from: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, preKeyStore: bob_store, @@ -111,6 +112,7 @@ class SessionTests: TestCaseBase { try signalDecryptPreKey( message: ctext_b, from: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, preKeyStore: bob_store, @@ -299,6 +301,7 @@ class SessionTests: TestCaseBase { _ = try! signalDecryptPreKey( message: ctext_b, from: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, preKeyStore: bob_store, @@ -311,6 +314,7 @@ class SessionTests: TestCaseBase { _ = try signalDecryptPreKey( message: ctext_b, from: mallory_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, preKeyStore: bob_store, @@ -419,6 +423,7 @@ class SessionTests: TestCaseBase { let plaintext = try signalDecryptPreKey( message: try! PreKeySignalMessage(bytes: usmc.contents), from: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, preKeyStore: bob_store, @@ -769,6 +774,7 @@ class SessionTests: TestCaseBase { _ = try signalDecryptPreKey( message: PreKeySignalMessage(bytes: bob_first_message), from: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, preKeyStore: alice_store,