diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index ee2876fce..74b356982 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,7 +1,7 @@ v0.91.0 - Support gRPC for getUploadForm() -- 1:1 message decryption now takes the local address as an extra argument +- 1:1 message encryption and decryption now takes the local address as an extra argument - Add `UserBasedAuthorization.UnrestrictedUnauthenticatedAccess` / `unrestrictedUnauthenticatedAccess` / `'unrestricted'` for `UnauthKeysService.getPreKeys` (and for 1:1 sealed sender messages in the future). diff --git a/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java b/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java index c151208ed..dd98eb3d9 100644 --- a/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java +++ b/java/client/src/main/java/org/signal/libsignal/metadata/SealedSessionCipher.java @@ -60,8 +60,11 @@ public class SealedSessionCipher { SenderCertificate senderCertificate, byte[] paddedPlaintext) throws InvalidKeyException, NoSessionException, UntrustedIdentityException { + SignalProtocolAddress localAddress = + new SignalProtocolAddress(this.localUuidAddress, this.localDeviceId); CiphertextMessage message = - new SessionCipher(signalProtocolStore, destinationAddress).encrypt(paddedPlaintext); + new SessionCipher(signalProtocolStore, localAddress, destinationAddress) + .encrypt(paddedPlaintext); UnidentifiedSenderMessageContent content = new UnidentifiedSenderMessageContent( message, @@ -237,11 +240,13 @@ public class SealedSessionCipher { } public int getSessionVersion(SignalProtocolAddress remoteAddress) { - return new SessionCipher(signalProtocolStore, remoteAddress).getSessionVersion(); + return new SessionCipher(signalProtocolStore, localAddress(), remoteAddress) + .getSessionVersion(); } public int getRemoteRegistrationId(SignalProtocolAddress remoteAddress) { - return new SessionCipher(signalProtocolStore, remoteAddress).getRemoteRegistrationId(); + return new SessionCipher(signalProtocolStore, localAddress(), remoteAddress) + .getRemoteRegistrationId(); } private byte[] decrypt(UnidentifiedSenderMessageContent message) @@ -260,13 +265,11 @@ public class SealedSessionCipher { switch (message.getType()) { case CiphertextMessage.WHISPER_TYPE: - return new SessionCipher(signalProtocolStore, sender) + return new SessionCipher(signalProtocolStore, localAddress(), sender) .decrypt(new SignalMessage(message.getContent())); case CiphertextMessage.PREKEY_TYPE: - return new SessionCipher(signalProtocolStore, sender) - .decrypt( - new PreKeySignalMessage(message.getContent()), - new SignalProtocolAddress(localUuidAddress, localDeviceId)); + return new SessionCipher(signalProtocolStore, localAddress(), sender) + .decrypt(new PreKeySignalMessage(message.getContent())); case CiphertextMessage.SENDERKEY_TYPE: return new GroupCipher(signalProtocolStore, sender).decrypt(message.getContent()); case CiphertextMessage.PLAINTEXT_CONTENT_TYPE: @@ -339,4 +342,8 @@ public class SealedSessionCipher { return groupId; } } + + private SignalProtocolAddress localAddress() { + return new SignalProtocolAddress(this.localUuidAddress, this.localDeviceId); + } } diff --git a/java/client/src/test/java/org/signal/libsignal/metadata/SealedSessionCipherTest.java b/java/client/src/test/java/org/signal/libsignal/metadata/SealedSessionCipherTest.java index e763b973e..a9ba1caa2 100644 --- a/java/client/src/test/java/org/signal/libsignal/metadata/SealedSessionCipherTest.java +++ b/java/client/src/test/java/org/signal/libsignal/metadata/SealedSessionCipherTest.java @@ -882,7 +882,9 @@ public class SealedSessionCipherTest extends TestCase { // Pretend Bob's reply fails to decrypt. SignalProtocolAddress aliceAddress = new SignalProtocolAddress("9d0652a3-dcc3-4d11-975f-74d61598733f", 1); - SessionCipher bobUnsealedCipher = new SessionCipher(bobStore, aliceAddress); + SignalProtocolAddress bobAddressForReply = + new SignalProtocolAddress("e80f7bbe-5b94-471e-bd8c-2173654ea3d1", 1); + SessionCipher bobUnsealedCipher = new SessionCipher(bobStore, bobAddressForReply, aliceAddress); CiphertextMessage bobMessage = bobUnsealedCipher.encrypt("reply".getBytes()); DecryptionErrorMessage errorMessage = diff --git a/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java b/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java index f97f6f63a..cb13c6207 100644 --- a/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java +++ b/java/client/src/test/java/org/signal/libsignal/protocol/SessionBuilderTest.java @@ -78,15 +78,16 @@ public class SessionBuilderTest { assertTrue(aliceStore.loadSession(BOB_ADDRESS).getSessionVersion() == expectedVersion); String originalMessage = "initial hello!"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = + new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessage.getType() == CiphertextMessage.PREKEY_TYPE); PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessage.serialize()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -132,7 +133,7 @@ public class SessionBuilderTest { aliceStore = new TestInMemorySignalProtocolStore(); var aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); - var aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + var aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); PreKeyBundle anotherBundle = bundleFactory.createBundle(bobStore); aliceSessionBuilder.process(anotherBundle); @@ -140,9 +141,9 @@ public class SessionBuilderTest { String originalMessage = "Good, fast, cheap: pick two"; var outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); - var bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + var bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize())); fail("shouldn't be trusted!"); } catch (UntrustedIdentityException uie) { bobStore.saveIdentity( @@ -150,8 +151,7 @@ public class SessionBuilderTest { } var plaintext = - bobSessionCipher.decrypt( - new PreKeySignalMessage(outgoingMessage.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(outgoingMessage.serialize())); assertTrue(new String(plaintext).equals(originalMessage)); Random random = new Random(); @@ -196,7 +196,7 @@ public class SessionBuilderTest { aliceSessionBuilder.process(bobPreKey); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessageOne = aliceSessionCipher.encrypt(originalMessage.getBytes()); CiphertextMessage outgoingMessageTwo = aliceSessionCipher.encrypt(originalMessage.getBytes()); @@ -205,9 +205,9 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessageOne.serialize()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); assertTrue(originalMessage.equals(new String(plaintext))); CiphertextMessage bobOutgoingMessage = bobSessionCipher.encrypt(originalMessage.getBytes()); @@ -221,9 +221,7 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessageTwo = new PreKeySignalMessage(outgoingMessageTwo.serialize()); - plaintext = - bobSessionCipher.decrypt( - new PreKeySignalMessage(incomingMessageTwo.serialize()), BOB_ADDRESS); + plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(incomingMessageTwo.serialize())); assertTrue(originalMessage.equals(new String(plaintext))); bobOutgoingMessage = bobSessionCipher.encrypt(originalMessage.getBytes()); @@ -261,7 +259,7 @@ public class SessionBuilderTest { assertTrue(aliceStore.loadSession(BOB_ADDRESS).getSessionVersion() == expectedVersion); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessage.getType() == CiphertextMessage.PREKEY_TYPE); @@ -269,8 +267,8 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessage.serialize()); assertTrue(!incomingMessage.getPreKeyId().isPresent()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - byte[] plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); + byte[] plaintext = bobSessionCipher.decrypt(incomingMessage); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -302,7 +300,7 @@ public class SessionBuilderTest { assertFalse(initialSession.hasSenderChain(Instant.EPOCH.plus(90, ChronoUnit.DAYS))); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes(), Instant.EPOCH); @@ -350,7 +348,7 @@ public class SessionBuilderTest { assertTrue(aliceStore.loadSession(BOB_ADDRESS).getSessionVersion() == expectedVersion); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessage.getType() == CiphertextMessage.PREKEY_TYPE); @@ -358,16 +356,16 @@ public class SessionBuilderTest { PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessage.serialize()); assertTrue(!incomingMessage.getPreKeyId().isPresent()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); - bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); + bobSessionCipher.decrypt(incomingMessage); assertTrue(bobStore.containsSession(ALICE_ADDRESS)); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); - SessionCipher bobSessionCipherForMallory = new SessionCipher(bobStore, MALLORY_ADDRESS); + SessionCipher bobSessionCipherForMallory = + new SessionCipher(bobStore, BOB_ADDRESS, MALLORY_ADDRESS); assertThrows( - ReusedBaseKeyException.class, - () -> bobSessionCipherForMallory.decrypt(incomingMessage, BOB_ADDRESS)); + ReusedBaseKeyException.class, () -> bobSessionCipherForMallory.decrypt(incomingMessage)); } } @@ -491,7 +489,7 @@ public class SessionBuilderTest { aliceSessionBuilder.process(bobPreKey); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessageOne = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessageOne.getType() == CiphertextMessage.PREKEY_TYPE); @@ -503,12 +501,12 @@ public class SessionBuilderTest { badMessage[badMessage.length - 10] ^= 0x01; PreKeySignalMessage incomingMessage = new PreKeySignalMessage(badMessage); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); byte[] plaintext = new byte[0]; try { - plaintext = bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + plaintext = bobSessionCipher.decrypt(incomingMessage); fail("Decrypt should have failed!"); } catch (InvalidMessageException e) { // good. @@ -516,7 +514,7 @@ public class SessionBuilderTest { assertTrue(bobStore.containsPreKey(bobPreKey.getPreKeyId())); - plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(goodMessage), BOB_ADDRESS); + plaintext = bobSessionCipher.decrypt(new PreKeySignalMessage(goodMessage)); assertTrue(originalMessage.equals(new String(plaintext))); assertFalse(bobStore.containsPreKey(bobPreKey.getPreKeyId())); @@ -541,16 +539,16 @@ public class SessionBuilderTest { aliceSessionBuilder.process(bobPreKey); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessageOne = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessageOne.getType() == CiphertextMessage.PREKEY_TYPE); PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessageOne.serialize()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + bobSessionCipher.decrypt(incomingMessage); fail("Decrypt should have failed!"); } catch (InvalidKeyIdException e) { assertEquals( @@ -578,16 +576,16 @@ public class SessionBuilderTest { aliceSessionBuilder.process(bobPreKey); String originalMessage = "Good, fast, cheap: pick two"; - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); CiphertextMessage outgoingMessageOne = aliceSessionCipher.encrypt(originalMessage.getBytes()); assertTrue(outgoingMessageOne.getType() == CiphertextMessage.PREKEY_TYPE); PreKeySignalMessage incomingMessage = new PreKeySignalMessage(outgoingMessageOne.serialize()); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); try { - bobSessionCipher.decrypt(incomingMessage, BOB_ADDRESS); + bobSessionCipher.decrypt(incomingMessage); fail("Decrypt should have failed!"); } catch (InvalidKeyIdException e) { fail("libsignal swallowed the exception"); @@ -605,8 +603,8 @@ public class SessionBuilderTest { InvalidKeyException, NoSessionException, UntrustedIdentityException { - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); String originalMessage = "smert ze smert"; CiphertextMessage aliceMessage = aliceSessionCipher.encrypt(originalMessage.getBytes()); diff --git a/java/client/src/test/java/org/signal/libsignal/protocol/SessionCipherTest.java b/java/client/src/test/java/org/signal/libsignal/protocol/SessionCipherTest.java index 517141082..0a4962d07 100644 --- a/java/client/src/test/java/org/signal/libsignal/protocol/SessionCipherTest.java +++ b/java/client/src/test/java/org/signal/libsignal/protocol/SessionCipherTest.java @@ -49,14 +49,14 @@ public class SessionCipherTest extends TestCase { SignalProtocolStore aliceStore = new TestInMemorySignalProtocolStore(); SignalProtocolStore bobStore = new TestInMemorySignalProtocolStore(); + SignalProtocolAddress bobAddress = new SignalProtocolAddress("+14159999999", 1); + SignalProtocolAddress aliceAddress = new SignalProtocolAddress("+14158888888", 1); - aliceStore.storeSession(new SignalProtocolAddress("+14159999999", 1), sessions.aliceSession); - bobStore.storeSession(new SignalProtocolAddress("+14158888888", 1), sessions.bobSession); + aliceStore.storeSession(bobAddress, sessions.aliceSession); + bobStore.storeSession(aliceAddress, sessions.bobSession); - SessionCipher aliceCipher = - new SessionCipher(aliceStore, new SignalProtocolAddress("+14159999999", 1)); - SessionCipher bobCipher = - new SessionCipher(bobStore, new SignalProtocolAddress("+14158888888", 1)); + SessionCipher aliceCipher = new SessionCipher(aliceStore, aliceAddress, bobAddress); + SessionCipher bobCipher = new SessionCipher(bobStore, bobAddress, aliceAddress); List inflight = new LinkedList<>(); @@ -88,8 +88,8 @@ public class SessionCipherTest extends TestCase { aliceStore.storeSession(bobAddress, sessions.aliceSession); bobStore.storeSession(aliceAddress, sessions.bobSession); - SessionCipher aliceCipher = new SessionCipher(aliceStore, bobAddress); - SessionCipher bobCipher = new SessionCipher(bobStore, aliceAddress); + SessionCipher aliceCipher = new SessionCipher(aliceStore, aliceAddress, bobAddress); + SessionCipher bobCipher = new SessionCipher(bobStore, bobAddress, aliceAddress); byte[] alicePlaintext = "This is a plaintext message.".getBytes(); CiphertextMessage message = aliceCipher.encrypt(alicePlaintext); @@ -121,8 +121,8 @@ public class SessionCipherTest extends TestCase { aliceStore.storeSession(bobAddress, sessions.aliceSession); bobStore.storeSession(aliceAddress, sessions.bobSession); - SessionCipher aliceCipher = new SessionCipher(aliceStore, bobAddress); - SessionCipher bobCipher = new SessionCipher(bobStore, aliceAddress); + SessionCipher aliceCipher = new SessionCipher(aliceStore, aliceAddress, bobAddress); + SessionCipher bobCipher = new SessionCipher(bobStore, bobAddress, aliceAddress); byte[] alicePlaintext = "This is a plaintext message.".getBytes(); CiphertextMessage message = aliceCipher.encrypt(alicePlaintext); @@ -151,14 +151,14 @@ public class SessionCipherTest extends TestCase { UntrustedIdentityException { SignalProtocolStore aliceStore = new TestInMemorySignalProtocolStore(); SignalProtocolStore bobStore = new TestInMemorySignalProtocolStore(); + SignalProtocolAddress bobAddress = new SignalProtocolAddress("+14159999999", 1); + SignalProtocolAddress aliceAddress = new SignalProtocolAddress("+14158888888", 1); - aliceStore.storeSession(new SignalProtocolAddress("+14159999999", 1), aliceSessionRecord); - bobStore.storeSession(new SignalProtocolAddress("+14158888888", 1), bobSessionRecord); + aliceStore.storeSession(bobAddress, aliceSessionRecord); + bobStore.storeSession(aliceAddress, bobSessionRecord); - SessionCipher aliceCipher = - new SessionCipher(aliceStore, new SignalProtocolAddress("+14159999999", 1)); - SessionCipher bobCipher = - new SessionCipher(bobStore, new SignalProtocolAddress("+14158888888", 1)); + SessionCipher aliceCipher = new SessionCipher(aliceStore, aliceAddress, bobAddress); + SessionCipher bobCipher = new SessionCipher(bobStore, bobAddress, aliceAddress); byte[] alicePlaintext = "This is a plaintext message.".getBytes(); CiphertextMessage message = aliceCipher.encrypt(alicePlaintext); diff --git a/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java b/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java index cf512f4c5..596e5425d 100644 --- a/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java +++ b/java/client/src/test/java/org/signal/libsignal/protocol/SimultaneousInitiateTests.java @@ -63,8 +63,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); aliceSessionBuilder.process(bobPreKeyBundle); bobSessionBuilder.process(alicePreKeyBundle); @@ -79,10 +79,9 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt( - new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); + aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -131,8 +130,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); aliceSessionBuilder.process(bobPreKeyBundle); bobSessionBuilder.process(alicePreKeyBundle); @@ -146,7 +145,7 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(bobPlaintext).equals("hey there")); assertEquals(bobStore.loadSession(ALICE_ADDRESS).getSessionVersion(), expectedVersion); @@ -156,7 +155,7 @@ public class SimultaneousInitiateTests { assertEquals(aliceResponse.getType(), CiphertextMessage.PREKEY_TYPE); byte[] responsePlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(aliceResponse.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(aliceResponse.serialize())); assertTrue(new String(responsePlaintext).equals("second message")); assertSessionIdEquals(aliceStore, bobStore); @@ -190,8 +189,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); aliceSessionBuilder.process(bobPreKeyBundle); bobSessionBuilder.process(alicePreKeyBundle); @@ -205,10 +204,9 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt( - new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); + aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -253,8 +251,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); aliceSessionBuilder.process(bobPreKeyBundle); bobSessionBuilder.process(alicePreKeyBundle); @@ -268,10 +266,9 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt( - new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); + aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -338,8 +335,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); for (int i = 0; i < 15; i++) { PreKeyBundle alicePreKeyBundle = bundleFactory.createBundle(aliceStore); @@ -357,10 +354,9 @@ public class SimultaneousInitiateTests { assertSessionIdNotEquals(aliceStore, bobStore); byte[] alicePlaintext = - aliceSessionCipher.decrypt( - new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); + aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -428,8 +424,8 @@ public class SimultaneousInitiateTests { SessionBuilder aliceSessionBuilder = new SessionBuilder(aliceStore, BOB_ADDRESS); SessionBuilder bobSessionBuilder = new SessionBuilder(bobStore, ALICE_ADDRESS); - SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, BOB_ADDRESS); - SessionCipher bobSessionCipher = new SessionCipher(bobStore, ALICE_ADDRESS); + SessionCipher aliceSessionCipher = new SessionCipher(aliceStore, ALICE_ADDRESS, BOB_ADDRESS); + SessionCipher bobSessionCipher = new SessionCipher(bobStore, BOB_ADDRESS, ALICE_ADDRESS); PreKeyBundle bobLostPreKeyBundle = bundleFactory.createBundle(bobStore); @@ -453,10 +449,9 @@ public class SimultaneousInitiateTests { assertFalse(isSessionIdEqual(aliceStore, bobStore)); byte[] alicePlaintext = - aliceSessionCipher.decrypt( - new PreKeySignalMessage(messageForAlice.serialize()), ALICE_ADDRESS); + aliceSessionCipher.decrypt(new PreKeySignalMessage(messageForAlice.serialize())); byte[] bobPlaintext = - bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(messageForBob.serialize())); assertTrue(new String(alicePlaintext).equals("sample message")); assertTrue(new String(bobPlaintext).equals("hey there")); @@ -508,8 +503,7 @@ public class SimultaneousInitiateTests { assertTrue(isSessionIdEqual(aliceStore, bobStore)); byte[] lostMessagePlaintext = - bobSessionCipher.decrypt( - new PreKeySignalMessage(lostMessageForBob.serialize()), BOB_ADDRESS); + bobSessionCipher.decrypt(new PreKeySignalMessage(lostMessageForBob.serialize())); assertTrue(new String(lostMessagePlaintext).equals("hey there")); assertFalse(isSessionIdEqual(aliceStore, bobStore)); diff --git a/java/shared/java/org/signal/libsignal/internal/Native.kt b/java/shared/java/org/signal/libsignal/internal/Native.kt index 81bbf8006..b0e28135b 100644 --- a/java/shared/java/org/signal/libsignal/internal/Native.kt +++ b/java/shared/java/org/signal/libsignal/internal/Native.kt @@ -1184,7 +1184,7 @@ internal object Native { @JvmStatic @Throws(Exception::class) public external fun SessionCipher_DecryptSignalMessage(message: ObjectHandle, protocolAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore): ByteArray @JvmStatic @Throws(Exception::class) - public external fun SessionCipher_EncryptMessage(ptext: ByteArray, protocolAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Long): CiphertextMessage + public external fun SessionCipher_EncryptMessage(ptext: ByteArray, protocolAddress: ObjectHandle, localAddress: ObjectHandle, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Long): CiphertextMessage @JvmStatic @Throws(Exception::class) public external fun SessionRecord_ArchiveCurrentState(sessionRecord: ObjectHandle): Unit diff --git a/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java b/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java index e740e62dc..a6341943f 100644 --- a/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java +++ b/java/shared/java/org/signal/libsignal/protocol/SessionCipher.java @@ -43,6 +43,7 @@ public class SessionCipher { private final PreKeyStore preKeyStore; private final SignedPreKeyStore signedPreKeyStore; private final KyberPreKeyStore kyberPreKeyStore; + private final SignalProtocolAddress localAddress; private final SignalProtocolAddress remoteAddress; /** @@ -59,17 +60,22 @@ public class SessionCipher { SignedPreKeyStore signedPreKeyStore, KyberPreKeyStore kyberPreKeyStore, IdentityKeyStore identityKeyStore, + SignalProtocolAddress localAddress, SignalProtocolAddress remoteAddress) { this.sessionStore = sessionStore; this.preKeyStore = preKeyStore; this.identityKeyStore = identityKeyStore; + this.localAddress = localAddress; this.remoteAddress = remoteAddress; this.signedPreKeyStore = signedPreKeyStore; this.kyberPreKeyStore = kyberPreKeyStore; } - public SessionCipher(SignalProtocolStore store, SignalProtocolAddress remoteAddress) { - this(store, store, store, store, store, remoteAddress); + public SessionCipher( + SignalProtocolStore store, + SignalProtocolAddress localAddress, + SignalProtocolAddress remoteAddress) { + this(store, store, store, store, store, localAddress, remoteAddress); } /** Exposed as public for {@code SealedSessionCipher}, do not use directly. */ @@ -154,7 +160,8 @@ public class SessionCipher { */ public CiphertextMessage encrypt(byte[] paddedMessage, Instant now) throws NoSessionException, UntrustedIdentityException { - try (NativeHandleGuard remoteAddress = new NativeHandleGuard(this.remoteAddress)) { + try (NativeHandleGuard remoteAddress = new NativeHandleGuard(this.remoteAddress); + NativeHandleGuard localAddressGuard = new NativeHandleGuard(this.localAddress); ) { return filterExceptions( NoSessionException.class, UntrustedIdentityException.class, @@ -162,6 +169,7 @@ public class SessionCipher { Native.SessionCipher_EncryptMessage( paddedMessage, remoteAddress.nativeHandle(), + localAddressGuard.nativeHandle(), bridge(sessionStore), _bridge(identityKeyStore), now.toEpochMilli())); @@ -181,7 +189,7 @@ public class SessionCipher { * @throws InvalidKeyException when the message is formatted incorrectly. * @throws UntrustedIdentityException when the {@link IdentityKey} of the sender is untrusted. */ - public byte[] decrypt(PreKeySignalMessage ciphertext, SignalProtocolAddress localAddress) + public byte[] decrypt(PreKeySignalMessage ciphertext) throws DuplicateMessageException, InvalidMessageException, InvalidKeyIdException, @@ -189,7 +197,7 @@ public class SessionCipher { UntrustedIdentityException { try (NativeHandleGuard ciphertextGuard = new NativeHandleGuard(ciphertext); NativeHandleGuard remoteAddressGuard = new NativeHandleGuard(this.remoteAddress); - NativeHandleGuard localAddressGuard = new NativeHandleGuard(localAddress); ) { + NativeHandleGuard localAddressGuard = new NativeHandleGuard(this.localAddress); ) { return filterExceptions( DuplicateMessageException.class, InvalidMessageException.class, diff --git a/node/ts/Native.ts b/node/ts/Native.ts index 126e7587f..2195cc4c4 100644 --- a/node/ts/Native.ts +++ b/node/ts/Native.ts @@ -312,7 +312,7 @@ type NativeFunctions = { SealedSenderDecryptionResult_GetDeviceId: (obj: Wrapper) => number; SealedSenderDecryptionResult_Message: (obj: Wrapper) => Uint8Array; SessionBuilder_ProcessPreKeyBundle: (bundle: Wrapper, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Timestamp) => Promise; - SessionCipher_EncryptMessage: (ptext: Uint8Array, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Timestamp) => Promise; + SessionCipher_EncryptMessage: (ptext: Uint8Array, protocolAddress: Wrapper, localAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, now: Timestamp) => Promise; SessionCipher_DecryptSignalMessage: (message: Wrapper, protocolAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore) => Promise>; SessionCipher_DecryptPreKeySignalMessage: (message: Wrapper, protocolAddress: Wrapper, localAddress: Wrapper, sessionStore: SessionStore, identityKeyStore: IdentityKeyStore, prekeyStore: PreKeyStore, signedPrekeyStore: SignedPreKeyStore, kyberPrekeyStore: KyberPreKeyStore) => Promise>; SealedSender_Encrypt: (destination: Wrapper, content: Wrapper, identityKeyStore: IdentityKeyStore) => Promise>; diff --git a/node/ts/index.ts b/node/ts/index.ts index 75018d408..4ef6cdd17 100644 --- a/node/ts/index.ts +++ b/node/ts/index.ts @@ -1291,6 +1291,7 @@ export function processPreKeyBundle( export async function signalEncrypt( message: Uint8Array, address: ProtocolAddress, + localAddress: ProtocolAddress, sessionStore: SessionStore, identityStore: IdentityKeyStore, now: Date = new Date() @@ -1299,6 +1300,7 @@ export async function signalEncrypt( await Native.SessionCipher_EncryptMessage( message, address, + localAddress, bridgeSessionStore(sessionStore), bridgeIdentityKeyStore(identityStore), now.getTime() @@ -1415,9 +1417,14 @@ export async function sealedSenderEncryptMessage( sessionStore: SessionStore, identityStore: IdentityKeyStore ): Promise> { + const localAddress = ProtocolAddress.new( + senderCert.senderUuid(), + senderCert.senderDeviceId() + ); const ciphertext = await signalEncrypt( message, address, + localAddress, sessionStore, identityStore ); diff --git a/node/ts/test/protocol/ProtocolTest.ts b/node/ts/test/protocol/ProtocolTest.ts index 8634030e9..041c07df9 100644 --- a/node/ts/test/protocol/ProtocolTest.ts +++ b/node/ts/test/protocol/ProtocolTest.ts @@ -224,6 +224,7 @@ it('DecryptionErrorMessage', async () => { const bCiphertext = await SignalClient.signalEncrypt( Buffer.from('reply', 'utf8'), aAddress, + bAddress, bSess, bKeys ); @@ -740,6 +741,7 @@ for (const testCase of sessionVersionTestCases) { const aCiphertext = await SignalClient.signalEncrypt( aMessage, bAddress, + aAddress, aliceStores.session, aliceStores.identity ); @@ -773,6 +775,7 @@ for (const testCase of sessionVersionTestCases) { const bCiphertext = await SignalClient.signalEncrypt( bMessage, aAddress, + bAddress, bobStores.session, bobStores.identity ); @@ -837,6 +840,7 @@ for (const testCase of sessionVersionTestCases) { const aCiphertext = await SignalClient.signalEncrypt( aMessage, bAddress, + aAddress, aliceStores.session, aliceStores.identity ); @@ -892,6 +896,7 @@ for (const testCase of sessionVersionTestCases) { const bCiphertext = await SignalClient.signalEncrypt( bMessage, aAddress, + bAddress, bobStores.session, bobStores.identity ); @@ -937,6 +942,7 @@ for (const testCase of sessionVersionTestCases) { const aliceStores = new TestStores(); const bobStores = new TestStores(); + const aAddress = SignalClient.ProtocolAddress.new('+14151111111', 1); const bAddress = SignalClient.ProtocolAddress.new('+19192222222', 1); const bPreKeyBundle = await testCase.makeBundle(bAddress, bobStores); @@ -957,6 +963,7 @@ for (const testCase of sessionVersionTestCases) { const aCiphertext = await SignalClient.signalEncrypt( aMessage, bAddress, + aAddress, aliceStores.session, aliceStores.identity, new Date('2020-01-01') @@ -975,6 +982,7 @@ for (const testCase of sessionVersionTestCases) { SignalClient.signalEncrypt( aMessage, bAddress, + aAddress, aliceStores.session, aliceStores.identity, new Date('2023-01-01') @@ -1007,6 +1015,7 @@ for (const testCase of sessionVersionTestCases) { const aCiphertext = await SignalClient.signalEncrypt( aMessage, bAddress, + aAddress, aliceStores.session, aliceStores.identity ); diff --git a/node/ts/test/protocol/SealedSenderTest.ts b/node/ts/test/protocol/SealedSenderTest.ts index 9776506f7..5f0586b5f 100644 --- a/node/ts/test/protocol/SealedSenderTest.ts +++ b/node/ts/test/protocol/SealedSenderTest.ts @@ -121,6 +121,7 @@ describe('SealedSender', () => { ); await bKyberStore.saveKyberPreKey(bKyberPrekeyId, bKyberPreKeyRecord); + const aAddress = SignalClient.ProtocolAddress.new(aUuid, aDeviceId); const bAddress = SignalClient.ProtocolAddress.new(bUuid, bDeviceId); await SignalClient.processPreKeyBundle( bPreKeyBundle, @@ -179,6 +180,7 @@ describe('SealedSender', () => { const innerMessage = await SignalClient.signalEncrypt( aPlaintext, bAddress, + aAddress, aSess, aKeys ); diff --git a/rust/bridge/shared/src/protocol.rs b/rust/bridge/shared/src/protocol.rs index 34cee26eb..f1156bdca 100644 --- a/rust/bridge/shared/src/protocol.rs +++ b/rust/bridge/shared/src/protocol.rs @@ -1031,6 +1031,7 @@ async fn SessionBuilder_ProcessPreKeyBundle( async fn SessionCipher_EncryptMessage( ptext: &[u8], protocol_address: &ProtocolAddress, + local_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_key_store: &mut dyn IdentityKeyStore, now: Timestamp, @@ -1039,6 +1040,7 @@ async fn SessionCipher_EncryptMessage( message_encrypt( ptext, protocol_address, + local_address, session_store, identity_key_store, now.into(), diff --git a/rust/protocol/benches/session.rs b/rust/protocol/benches/session.rs index 15163e22d..4c65c1fac 100644 --- a/rust/protocol/benches/session.rs +++ b/rust/protocol/benches/session.rs @@ -32,9 +32,14 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr .now_or_never() .expect("sync")?; - let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync")?; + let message_to_decrypt = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync")?; assert_eq!( message_to_decrypt.message_type(), CiphertextMessageType::Whisper @@ -64,9 +69,14 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr .now_or_never() .expect("sync")?; - let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync")?; + let message_to_decrypt = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync")?; assert_eq!( message_to_decrypt.message_type(), CiphertextMessageType::Whisper @@ -74,10 +84,15 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr c.bench_function("encrypting on an existing chain", |b| { b.iter(|| { - support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync") - .expect("success"); + support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync") + .expect("success"); }) }); c.bench_function("decrypting on an existing chain", |b| { @@ -213,9 +228,14 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr let original_message_to_decrypt = message_to_decrypt; // ...send another message to archive on Bob's side... - let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync")?; + let message_to_decrypt = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync")?; assert_eq!( message_to_decrypt.message_type(), CiphertextMessageType::PreKey, @@ -248,9 +268,14 @@ pub fn session_encrypt_result(c: &mut Criterion) -> Result<(), SignalProtocolErr .now_or_never() .expect("sync")?; // ...and prepare another message to benchmark decrypting. - let message_to_decrypt = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync")?; + let message_to_decrypt = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync")?; assert_eq!( message_to_decrypt.message_type(), CiphertextMessageType::PreKey, @@ -318,10 +343,15 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro .expect("sync")?; // Get the pre-key message out of the way. - let ctext = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync") - .expect("success"); + let ctext = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync") + .expect("success"); let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") @@ -329,10 +359,15 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro c.bench_function("session encrypt+decrypt 1 way", |b| { b.iter(|| { - let ctext = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync") - .expect("success"); + let ctext = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync") + .expect("success"); let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") @@ -342,19 +377,29 @@ pub fn session_encrypt_decrypt_result(c: &mut Criterion) -> Result<(), SignalPro c.bench_function("session encrypt+decrypt ping pong", |b| { b.iter(|| { - let ctext = support::encrypt(&mut alice_store, &bob_address, "a short message") - .now_or_never() - .expect("sync") - .expect("success"); + let ctext = support::encrypt( + &mut alice_store, + &bob_address, + &alice_address, + "a short message", + ) + .now_or_never() + .expect("sync") + .expect("success"); let _ptext = support::decrypt(&mut bob_store, &alice_address, &bob_address, &ctext) .now_or_never() .expect("sync") .expect("success"); - let ctext = support::encrypt(&mut bob_store, &alice_address, "a short message") - .now_or_never() - .expect("sync") - .expect("success"); + let ctext = support::encrypt( + &mut bob_store, + &alice_address, + &bob_address, + "a short message", + ) + .now_or_never() + .expect("sync") + .expect("success"); let _ptext = support::decrypt(&mut alice_store, &bob_address, &alice_address, &ctext) .now_or_never() .expect("sync") diff --git a/rust/protocol/cross-version-testing/src/current.rs b/rust/protocol/cross-version-testing/src/current.rs index 88bc5f0a6..a702d99de 100644 --- a/rust/protocol/cross-version-testing/src/current.rs +++ b/rust/protocol/cross-version-testing/src/current.rs @@ -143,10 +143,16 @@ impl super::LibSignalProtocolStore for LibSignalProtocolCurrent { .expect("can process pre-key bundles") } - fn encrypt(&mut self, remote: &str, msg: &[u8]) -> (Vec, CiphertextMessageType) { + fn encrypt( + &mut self, + remote: &str, + local: &str, + msg: &[u8], + ) -> (Vec, CiphertextMessageType) { let encrypted = message_encrypt( msg, &address(remote), + &address(local), &mut self.0.session_store, &mut self.0.identity_store, SystemTime::now(), diff --git a/rust/protocol/cross-version-testing/src/lib.rs b/rust/protocol/cross-version-testing/src/lib.rs index 58012ae95..e3bb6e2db 100644 --- a/rust/protocol/cross-version-testing/src/lib.rs +++ b/rust/protocol/cross-version-testing/src/lib.rs @@ -13,7 +13,12 @@ pub trait LibSignalProtocolStore { fn version(&self) -> &'static str; fn create_pre_key_bundle(&mut self) -> PreKeyBundle; fn process_pre_key_bundle(&mut self, remote: &str, pre_key_bundle: PreKeyBundle); - fn encrypt(&mut self, remote: &str, msg: &[u8]) -> (Vec, CiphertextMessageType); + fn encrypt( + &mut self, + remote: &str, + local: &str, + msg: &[u8], + ) -> (Vec, CiphertextMessageType); fn decrypt( &mut self, remote: &str, diff --git a/rust/protocol/cross-version-testing/src/v70.rs b/rust/protocol/cross-version-testing/src/v70.rs index 7400cc1a7..df62a3ea6 100644 --- a/rust/protocol/cross-version-testing/src/v70.rs +++ b/rust/protocol/cross-version-testing/src/v70.rs @@ -170,7 +170,12 @@ impl super::LibSignalProtocolStore for LibSignalProtocolV70 { .expect("can process pre-key bundles") } - fn encrypt(&mut self, remote: &str, msg: &[u8]) -> (Vec, super::CiphertextMessageType) { + fn encrypt( + &mut self, + remote: &str, + _local: &str, + msg: &[u8], + ) -> (Vec, super::CiphertextMessageType) { let encrypted = message_encrypt( msg, &address(remote), diff --git a/rust/protocol/cross-version-testing/tests/session.rs b/rust/protocol/cross-version-testing/tests/session.rs index 8415152c2..61a398169 100644 --- a/rust/protocol/cross-version-testing/tests/session.rs +++ b/rust/protocol/cross-version-testing/tests/session.rs @@ -21,7 +21,7 @@ fn test_basic_prekey() { let original_message = "L'homme est condamné à être libre".as_bytes(); let (outgoing_message, outgoing_message_type) = - alice_store.encrypt(bob_name, original_message); + alice_store.encrypt(bob_name, alice_name, original_message); assert_eq!(outgoing_message_type, CiphertextMessageType::PreKey); let ptext = bob_store.decrypt( @@ -33,7 +33,8 @@ fn test_basic_prekey() { assert_eq!(&ptext, original_message); let bobs_response = "Who watches the watchers?".as_bytes(); - let (bob_outgoing, bob_outgoing_type) = bob_store.encrypt(alice_name, bobs_response); + let (bob_outgoing, bob_outgoing_type) = + bob_store.encrypt(alice_name, bob_name, bobs_response); assert_eq!(bob_outgoing_type, CiphertextMessageType::Whisper); let alice_decrypts = @@ -52,7 +53,8 @@ fn run_interaction( ) { let alice_ptext = b"It's rabbit season"; - let (alice_message, alice_message_type) = alice_store.encrypt(bob_name, alice_ptext); + let (alice_message, alice_message_type) = + alice_store.encrypt(bob_name, alice_name, alice_ptext); assert_eq!(alice_message_type, CiphertextMessageType::Whisper); assert_eq!( &bob_store.decrypt(alice_name, bob_name, &alice_message, alice_message_type), @@ -61,7 +63,7 @@ fn run_interaction( let bob_ptext = b"It's duck season"; - let (bob_message, bob_message_type) = bob_store.encrypt(alice_name, bob_ptext); + let (bob_message, bob_message_type) = bob_store.encrypt(alice_name, bob_name, bob_ptext); assert_eq!(bob_message_type, CiphertextMessageType::Whisper); assert_eq!( &alice_store.decrypt(bob_name, alice_name, &bob_message, bob_message_type), @@ -71,7 +73,7 @@ fn run_interaction( for i in 0..10 { let alice_ptext = format!("A->B message {}", i); let (alice_message, alice_message_type) = - alice_store.encrypt(bob_name, alice_ptext.as_bytes()); + alice_store.encrypt(bob_name, alice_name, alice_ptext.as_bytes()); assert_eq!(alice_message_type, CiphertextMessageType::Whisper); assert_eq!( &bob_store.decrypt(alice_name, bob_name, &alice_message, alice_message_type), @@ -81,7 +83,8 @@ fn run_interaction( for i in 0..10 { let bob_ptext = format!("B->A message {}", i); - let (bob_message, bob_message_type) = bob_store.encrypt(alice_name, bob_ptext.as_bytes()); + let (bob_message, bob_message_type) = + bob_store.encrypt(alice_name, bob_name, bob_ptext.as_bytes()); assert_eq!(bob_message_type, CiphertextMessageType::Whisper); assert_eq!( &alice_store.decrypt(bob_name, alice_name, &bob_message, bob_message_type), @@ -93,13 +96,13 @@ fn run_interaction( for i in 0..10 { let alice_ptext = format!("A->B OOO message {}", i); - let (alice_message, _) = alice_store.encrypt(bob_name, alice_ptext.as_bytes()); + let (alice_message, _) = alice_store.encrypt(bob_name, alice_name, alice_ptext.as_bytes()); alice_ooo_messages.push((alice_ptext, alice_message)); } for i in 0..10 { let alice_ptext = format!("A->B post-OOO message {}", i); - let (alice_message, _) = alice_store.encrypt(bob_name, alice_ptext.as_bytes()); + let (alice_message, _) = alice_store.encrypt(bob_name, alice_name, alice_ptext.as_bytes()); assert_eq!( &bob_store.decrypt( alice_name, @@ -113,7 +116,7 @@ fn run_interaction( for i in 0..10 { let bob_ptext = format!("B->A message post-OOO {}", i); - let (bob_message, _) = bob_store.encrypt(alice_name, bob_ptext.as_bytes()); + let (bob_message, _) = bob_store.encrypt(alice_name, bob_name, bob_ptext.as_bytes()); assert_eq!( &alice_store.decrypt( bob_name, diff --git a/rust/protocol/fuzz/fuzz_targets/interaction.rs b/rust/protocol/fuzz/fuzz_targets/interaction.rs index fa7e43492..6cb8c4892 100644 --- a/rust/protocol/fuzz/fuzz_targets/interaction.rs +++ b/rust/protocol/fuzz/fuzz_targets/interaction.rs @@ -173,6 +173,7 @@ impl Participant { let outgoing_message = message_encrypt( &buffer, &them.address, + &self.address, &mut self.store.session_store, &mut self.store.identity_store, SystemTime::UNIX_EPOCH, diff --git a/rust/protocol/src/proto/wire.proto b/rust/protocol/src/proto/wire.proto index 4e2647f0e..c1b074633 100644 --- a/rust/protocol/src/proto/wire.proto +++ b/rust/protocol/src/proto/wire.proto @@ -13,7 +13,7 @@ message SignalMessage { optional uint32 previous_counter = 3; optional bytes ciphertext = 4; optional bytes pq_ratchet = 5; - optional bytes recipient_address = 6; + optional bytes addresses = 6; } message PreKeySignalMessage { diff --git a/rust/protocol/src/protocol.rs b/rust/protocol/src/protocol.rs index 67c756bac..5531b9301 100644 --- a/rust/protocol/src/protocol.rs +++ b/rust/protocol/src/protocol.rs @@ -68,7 +68,7 @@ pub struct SignalMessage { previous_counter: u32, ciphertext: Box<[u8]>, pq_ratchet: spqr::SerializedState, - recipient_address: Option>, + addresses: Option>, serialized: Box<[u8]>, } @@ -79,7 +79,7 @@ impl SignalMessage { pub fn new( message_version: u8, mac_key: &[u8], - recipient_address: Option<&ProtocolAddress>, + addresses: Option<(&ProtocolAddress, &ProtocolAddress)>, sender_ratchet_key: PublicKey, counter: u32, previous_counter: u32, @@ -88,6 +88,8 @@ impl SignalMessage { receiver_identity_key: &IdentityKey, pq_ratchet: &[u8], ) -> Result { + let addresses = + addresses.and_then(|(sender, recipient)| Self::serialize_addresses(sender, recipient)); let message = proto::wire::SignalMessage { ratchet_key: Some(sender_ratchet_key.serialize().into_vec()), counter: Some(counter), @@ -98,7 +100,7 @@ impl SignalMessage { } else { Some(pq_ratchet.to_vec()) }, - recipient_address: recipient_address.and_then(Self::serialize_recipient_address), + addresses, }; let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH); serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION); @@ -120,7 +122,7 @@ impl SignalMessage { previous_counter, ciphertext: ciphertext.into(), pq_ratchet: pq_ratchet.to_vec(), - recipient_address: message.recipient_address.map(Into::into), + addresses: message.addresses.map(Into::into), serialized, }) } @@ -181,8 +183,9 @@ impl SignalMessage { Ok(true) } - pub fn verify_mac_with_recipient_address( + pub fn verify_mac_with_addresses( &self, + sender_address: &ProtocolAddress, recipient_address: &ProtocolAddress, sender_identity_key: &IdentityKey, receiver_identity_key: &IdentityKey, @@ -192,22 +195,29 @@ impl SignalMessage { return Ok(false); } - // If the sender didn't include a recipient address, accept the message for + // If the sender didn't include addresses, accept the message for // backward compatibility with older clients. - let Some(encoded_recipient_address) = &self.recipient_address else { + let Some(encoded_addresses) = &self.addresses else { return Ok(true); }; - // Only match valid Service IDs. - let Some(expected) = Self::serialize_recipient_address(recipient_address) else { - log::warn!("Local address not a valid Service ID {}", recipient_address); + let Some(expected) = Self::serialize_addresses(sender_address, recipient_address) else { + log::warn!( + "Local addresses not valid Service IDs: sender={}, recipient={}", + sender_address, + recipient_address, + ); return Ok(false); }; - if bool::from(expected.ct_eq(encoded_recipient_address.as_ref())) { + if bool::from(expected.ct_eq(encoded_addresses.as_ref())) { Ok(true) } else { - log::warn!("Recipient address mismatch for {}", recipient_address); + log::warn!( + "Address mismatch: sender={}, recipient={}", + sender_address, + recipient_address, + ); Ok(false) } } @@ -235,12 +245,20 @@ impl SignalMessage { Ok(result) } - /// Serializes the recipient address to Service-Id-Fixed-Width-Binary (17 bytes) + device ID - /// (1 byte). Returns `None` if the address name is not a valid ServiceId. - fn serialize_recipient_address(recipient_address: &ProtocolAddress) -> Option> { - let service_id = ServiceId::parse_from_service_id_string(recipient_address.name())?; - let mut bytes = service_id.service_id_fixed_width_binary().to_vec(); - bytes.push(recipient_address.device_id().into()); + /// Serializes sender and recipient addresses into a single byte vector. + /// Returns `None` if either address name is not a valid ServiceId. + fn serialize_addresses( + sender: &ProtocolAddress, + recipient: &ProtocolAddress, + ) -> Option> { + let sender_service_id = ServiceId::parse_from_service_id_string(sender.name())?; + let recipient_service_id = ServiceId::parse_from_service_id_string(recipient.name())?; + + let mut bytes = Vec::with_capacity(36); + bytes.extend_from_slice(&sender_service_id.service_id_fixed_width_binary()); + bytes.push(sender.device_id().into()); + bytes.extend_from_slice(&recipient_service_id.service_id_fixed_width_binary()); + bytes.push(recipient.device_id().into()); Some(bytes) } } @@ -294,7 +312,7 @@ impl TryFrom<&[u8]> for SignalMessage { previous_counter, ciphertext, pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]), - recipient_address: proto_structure.recipient_address.map(Into::into), + addresses: proto_structure.addresses.map(Into::into), serialized: Box::from(value), }) } @@ -986,13 +1004,19 @@ mod tests { let sender_ratchet_key_pair = KeyPair::generate(csprng); let sender_identity_key_pair = KeyPair::generate(csprng); let receiver_identity_key_pair = KeyPair::generate(csprng); - let recipient_address = - ProtocolAddress::new("recipient".to_owned(), DeviceId::new(1).unwrap()); + let sender_address = ProtocolAddress::new( + "31415926-5358-9793-2384-626433827950".to_owned(), + DeviceId::new(1).unwrap(), + ); + let recipient_address = ProtocolAddress::new( + "27182818-2845-9045-2353-602874713526".to_owned(), + DeviceId::new(1).unwrap(), + ); SignalMessage::new( 4, &mac_key, - Some(&recipient_address), + Some((&sender_address, &recipient_address)), sender_ratchet_key_pair.public_key, 42, 41, @@ -1009,7 +1033,7 @@ mod tests { assert_eq!(m1.counter, m2.counter); assert_eq!(m1.previous_counter, m2.previous_counter); assert_eq!(m1.ciphertext, m2.ciphertext); - assert_eq!(m1.recipient_address, m2.recipient_address); + assert_eq!(m1.addresses, m2.addresses); assert_eq!(m1.serialized, m2.serialized); } @@ -1078,8 +1102,7 @@ mod tests { } #[test] - fn test_signal_message_verify_mac_accepts_legacy_message_without_recipient_address() - -> Result<()> { + fn test_signal_message_verify_mac_accepts_legacy_message_without_addresses() -> Result<()> { let mut csprng = OsRng.unwrap_err(); let mut mac_key = [0u8; 32]; csprng.fill_bytes(&mut mac_key); @@ -1090,15 +1113,19 @@ mod tests { let sender_ratchet_key_pair = KeyPair::generate(&mut csprng); let sender_identity_key_pair = KeyPair::generate(&mut csprng); let receiver_identity_key_pair = KeyPair::generate(&mut csprng); + let sender_address = ProtocolAddress::new( + "16180339-8874-9894-8482-045868343656".to_owned(), + DeviceId::new(1).unwrap(), + ); let recipient_address = ProtocolAddress::new( - "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + "14142135-6237-3095-0488-016887242096".to_owned(), DeviceId::new(1).unwrap(), ); let message = SignalMessage::new( 4, &mac_key, - Some(&recipient_address), + Some((&sender_address, &recipient_address)), sender_ratchet_key_pair.public_key, 42, 41, @@ -1112,7 +1139,7 @@ mod tests { &message.serialized()[1..message.serialized().len() - SignalMessage::MAC_LENGTH], ) .expect("valid protobuf"); - proto_structure.recipient_address = None; + proto_structure.addresses = None; let mut serialized = vec![((message.message_version() & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION]; @@ -1126,7 +1153,8 @@ mod tests { serialized.extend_from_slice(&mac); let legacy_message = SignalMessage::try_from(serialized.as_slice())?; - assert!(legacy_message.verify_mac_with_recipient_address( + assert!(legacy_message.verify_mac_with_addresses( + &sender_address, &recipient_address, &sender_identity_key_pair.public_key.into(), &receiver_identity_key_pair.public_key.into(), @@ -1137,7 +1165,7 @@ mod tests { } #[test] - fn test_signal_message_verify_mac_rejects_wrong_recipient_address() -> Result<()> { + fn test_signal_message_verify_mac_rejects_wrong_address() -> Result<()> { let mut csprng = OsRng.unwrap_err(); let mut mac_key = [0u8; 32]; csprng.fill_bytes(&mut mac_key); @@ -1148,19 +1176,23 @@ mod tests { let sender_ratchet_key_pair = KeyPair::generate(&mut csprng); let sender_identity_key_pair = KeyPair::generate(&mut csprng); let receiver_identity_key_pair = KeyPair::generate(&mut csprng); - let recipient_address = ProtocolAddress::new( - "9d0652a3-dcc3-4d11-975f-74d61598733f".to_owned(), + let sender_address = ProtocolAddress::new( + "deadbeef-cafe-babe-feed-faceb00c0ffe".to_owned(), DeviceId::new(1).unwrap(), ); - let wrong_recipient_address = ProtocolAddress::new( - "a5e2f8d1-4b3c-4e7a-8f9d-1c2b3d4e5f6a".to_owned(), + let recipient_address = ProtocolAddress::new( + "01120358-1321-3455-0891-44233377610a".to_owned(), + DeviceId::new(1).unwrap(), + ); + let wrong_address = ProtocolAddress::new( + "02030507-1113-1719-2329-313741434753".to_owned(), DeviceId::new(1).unwrap(), ); let message = SignalMessage::new( 4, &mac_key, - Some(&recipient_address), + Some((&sender_address, &recipient_address)), sender_ratchet_key_pair.public_key, 42, 41, @@ -1170,8 +1202,28 @@ mod tests { b"", )?; - assert!(!message.verify_mac_with_recipient_address( - &wrong_recipient_address, + // Wrong sender address should be rejected. + assert!(!message.verify_mac_with_addresses( + &wrong_address, + &recipient_address, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + &mac_key, + )?); + + // Wrong recipient address should be rejected. + assert!(!message.verify_mac_with_addresses( + &sender_address, + &wrong_address, + &sender_identity_key_pair.public_key.into(), + &receiver_identity_key_pair.public_key.into(), + &mac_key, + )?); + + // Correct addresses should be accepted. + assert!(message.verify_mac_with_addresses( + &sender_address, + &recipient_address, &sender_identity_key_pair.public_key.into(), &receiver_identity_key_pair.public_key.into(), &mac_key, diff --git a/rust/protocol/src/sealed_sender.rs b/rust/protocol/src/sealed_sender.rs index a00d128a9..c2d44b587 100644 --- a/rust/protocol/src/sealed_sender.rs +++ b/rust/protocol/src/sealed_sender.rs @@ -894,8 +894,20 @@ pub async fn sealed_sender_encrypt( now: SystemTime, rng: &mut R, ) -> Result> { - let message = - message_encrypt(ptext, destination, session_store, identity_store, now, rng).await?; + let sender_address = ProtocolAddress::new( + sender_cert.sender_uuid()?.to_owned(), + sender_cert.sender_device_id()?, + ); + let message = message_encrypt( + ptext, + destination, + &sender_address, + session_store, + identity_store, + now, + rng, + ) + .await?; let usmc = UnidentifiedSenderMessageContent::new( message.message_type(), sender_cert.clone(), diff --git a/rust/protocol/src/session_cipher.rs b/rust/protocol/src/session_cipher.rs index 9347aeeb9..d43a98d7e 100644 --- a/rust/protocol/src/session_cipher.rs +++ b/rust/protocol/src/session_cipher.rs @@ -19,6 +19,7 @@ use crate::{ pub async fn message_encrypt( ptext: &[u8], remote_address: &ProtocolAddress, + local_address: &ProtocolAddress, session_store: &mut dyn SessionStore, identity_store: &mut dyn IdentityKeyStore, now: SystemTime, @@ -92,7 +93,7 @@ pub async fn message_encrypt( let message = SignalMessage::new( session_version, message_keys.mac_key(), - Some(remote_address), + Some((local_address, remote_address)), sender_ephemeral, chain_key.index(), previous_counter, @@ -667,7 +668,8 @@ fn decrypt_message_with_state( ))?; let mac_valid = match local_address { - Some(local_address) => ciphertext.verify_mac_with_recipient_address( + Some(local_address) => ciphertext.verify_mac_with_addresses( + remote_address, local_address, &their_identity_key, &state.local_identity_key()?, diff --git a/rust/protocol/test-support/src/lib.rs b/rust/protocol/test-support/src/lib.rs index c4093b6b8..fef9b255f 100644 --- a/rust/protocol/test-support/src/lib.rs +++ b/rust/protocol/test-support/src/lib.rs @@ -241,6 +241,7 @@ impl Participant { let outgoing_message = message_encrypt( &buffer, &them.address, + &self.address, &mut self.state.store.session_store, &mut self.state.store.identity_store, SystemTime::UNIX_EPOCH, diff --git a/rust/protocol/tests/sealed_sender.rs b/rust/protocol/tests/sealed_sender.rs index d8297d782..bfe1af137 100644 --- a/rust/protocol/tests/sealed_sender.rs +++ b/rust/protocol/tests/sealed_sender.rs @@ -451,6 +451,7 @@ fn test_sealed_sender_multi_recipient() -> Result<(), SignalProtocolError> { let alice_uuid = "9d0652a3-dcc3-4d11-975f-74d61598733f".to_string(); let bob_uuid = "796abedb-ca4e-4f18-8803-1fde5b921f9f".to_string(); + let alice_uuid_address = ProtocolAddress::new(alice_uuid.clone(), alice_device_id); let bob_uuid_address = ProtocolAddress::new(bob_uuid.clone(), bob_device_id); let mut alice_store = support::test_in_memory_protocol_store()?; @@ -493,6 +494,7 @@ fn test_sealed_sender_multi_recipient() -> Result<(), SignalProtocolError> { let alice_message = message_encrypt( &alice_ptext, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::now(), @@ -561,6 +563,7 @@ fn test_sealed_sender_multi_recipient() -> Result<(), SignalProtocolError> { let alice_message = message_encrypt( &alice_ptext, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::now(), @@ -621,6 +624,7 @@ fn test_sealed_sender_multi_recipient() -> Result<(), SignalProtocolError> { let alice_message = message_encrypt( &alice_ptext, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::now(), @@ -698,6 +702,7 @@ fn test_sealed_sender_multi_recipient_encrypt_with_archived_session() let alice_uuid = "9d0652a3-dcc3-4d11-975f-74d61598733f".to_string(); let bob_uuid = "796abedb-ca4e-4f18-8803-1fde5b921f9f".to_string(); + let alice_uuid_address = ProtocolAddress::new(alice_uuid.clone(), alice_device_id); let bob_uuid_address = ProtocolAddress::new(bob_uuid.clone(), bob_device_id); let mut alice_store = support::test_in_memory_protocol_store()?; @@ -740,6 +745,7 @@ fn test_sealed_sender_multi_recipient_encrypt_with_archived_session() let alice_message = message_encrypt( &alice_ptext, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::now(), @@ -803,6 +809,7 @@ fn test_sealed_sender_multi_recipient_encrypt_with_bad_registration_id() let alice_uuid = "9d0652a3-dcc3-4d11-975f-74d61598733f".to_string(); let bob_uuid = "796abedb-ca4e-4f18-8803-1fde5b921f9f".to_string(); + let alice_uuid_address = ProtocolAddress::new(alice_uuid.clone(), alice_device_id); let bob_uuid_address = ProtocolAddress::new(bob_uuid.clone(), bob_device_id); let mut alice_store = support::test_in_memory_protocol_store()?; @@ -846,6 +853,7 @@ fn test_sealed_sender_multi_recipient_encrypt_with_bad_registration_id() let alice_message = message_encrypt( &alice_ptext, &bob_uuid_address, + &alice_uuid_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::now(), @@ -926,6 +934,7 @@ fn test_decryption_error_in_sealed_sender() -> Result<(), SignalProtocolError> { let bob_first_message = message_encrypt( b"swim camp", &alice_uuid_address, + &bob_uuid_address, &mut bob_store.session_store, &mut bob_store.identity_store, SystemTime::now(), @@ -951,6 +960,7 @@ fn test_decryption_error_in_sealed_sender() -> Result<(), SignalProtocolError> { let bob_message = message_encrypt( b"space camp", &alice_uuid_address, + &bob_uuid_address, &mut bob_store.session_store, &mut bob_store.identity_store, SystemTime::now(), diff --git a/rust/protocol/tests/session.rs b/rust/protocol/tests/session.rs index 7ae25bfdc..806d193db 100644 --- a/rust/protocol/tests/session.rs +++ b/rust/protocol/tests/session.rs @@ -77,7 +77,7 @@ fn test_basic_prekey() -> TestResult { let original_message = "L'homme est condamné à être libre"; - let outgoing_message = encrypt(alice_store, &bob_address, original_message).await?; + let outgoing_message = encrypt(alice_store, &bob_address, &alice_address, original_message).await?; assert_eq!( outgoing_message.message_type(), @@ -127,7 +127,7 @@ fn test_basic_prekey() -> TestResult { assert_eq!(bobs_session_with_alice.alice_base_key()?.len(), 32 + 1); let bob_outgoing = - encrypt(&mut bob_store_builder.store, &alice_address, bobs_response).await?; + encrypt(&mut bob_store_builder.store, &alice_address, &bob_address, bobs_response).await?; assert_eq!(bob_outgoing.message_type(), CiphertextMessageType::Whisper); @@ -170,7 +170,8 @@ fn test_basic_prekey() -> TestResult { .await?; let outgoing_message = - encrypt(&mut alter_alice_store, &bob_address, original_message).await?; + encrypt(&mut alter_alice_store, &bob_address, &alice_address, original_message) + .await?; assert!(matches!( decrypt(&mut bob_store_builder.store, &alice_address, &bob_address, &outgoing_message) @@ -280,11 +281,22 @@ fn test_chain_jump_over_limit() -> TestResult { pub const MAX_FORWARD_JUMPS: usize = 25_000; for _i in 0..(MAX_FORWARD_JUMPS + 1) { - let _msg = - encrypt(alice_store, &bob_address, "Yet another message for you").await?; + let _msg = encrypt( + alice_store, + &bob_address, + &alice_address, + "Yet another message for you", + ) + .await?; } - let too_far = encrypt(alice_store, &bob_address, "Now you have gone too far").await?; + let too_far = encrypt( + alice_store, + &bob_address, + &alice_address, + "Now you have gone too far", + ) + .await?; assert!( decrypt( @@ -345,12 +357,22 @@ fn test_chain_jump_over_limit_with_self() -> TestResult { pub const MAX_FORWARD_JUMPS: usize = 25_000; for _i in 0..(MAX_FORWARD_JUMPS + 1) { - let _msg = - encrypt(a1_store, &a2_address, "Yet another message for yourself").await?; + let _msg = encrypt( + a1_store, + &a2_address, + &a1_address, + "Yet another message for yourself", + ) + .await?; } - let too_far = - encrypt(a1_store, &a2_address, "This is the song that never ends").await?; + let too_far = encrypt( + a1_store, + &a2_address, + &a1_address, + "This is the song that never ends", + ) + .await?; let ptext = decrypt( &mut a2_store_builder.store, @@ -484,8 +506,10 @@ fn test_repeat_bundle_message() -> TestResult { let original_message = "L'homme est condamné à être libre"; - let outgoing_message1 = encrypt(alice_store, &bob_address, original_message).await?; - let outgoing_message2 = encrypt(alice_store, &bob_address, original_message).await?; + let outgoing_message1 = + encrypt(alice_store, &bob_address, &alice_address, original_message).await?; + let outgoing_message2 = + encrypt(alice_store, &bob_address, &alice_address, original_message).await?; assert_eq!( outgoing_message1.message_type(), @@ -515,6 +539,7 @@ fn test_repeat_bundle_message() -> TestResult { let bob_outgoing = encrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, original_message, ) .await?; @@ -547,6 +572,7 @@ fn test_repeat_bundle_message() -> TestResult { let bob_outgoing = encrypt( &mut bob_store_builder.store, &alice_address, + &bob_address, original_message, ) .await?; @@ -618,7 +644,8 @@ fn test_bad_message_bundle() -> TestResult { let original_message = "L'homme est condamné à être libre"; assert!(bob_store.get_pre_key(pre_key_id).await.is_ok()); - let outgoing_message = encrypt(alice_store, &bob_address, original_message).await?; + let outgoing_message = + encrypt(alice_store, &bob_address, &alice_address, original_message).await?; assert_eq!( outgoing_message.message_type(), @@ -711,7 +738,8 @@ fn test_optional_one_time_prekey() -> TestResult { let original_message = "L'homme est condamné à être libre"; - let outgoing_message = encrypt(alice_store, &bob_address, original_message).await?; + let outgoing_message = + encrypt(alice_store, &bob_address, &alice_address, original_message).await?; assert_eq!( outgoing_message.message_type(), @@ -781,7 +809,13 @@ fn test_message_key_limits() -> TestResult { for i in 0..TOO_MANY_MESSAGES { inflight.push( - encrypt(&mut alice_store, &bob_address, &format!("It's over {i}")).await?, + encrypt( + &mut alice_store, + &bob_address, + &alice_address, + &format!("It's over {i}"), + ) + .await?, ); } @@ -885,8 +919,10 @@ fn test_basic_simultaneous_initiate() -> TestResult { ) .await?; - let message_for_bob = encrypt(alice_store, &bob_address, "hi bob").await?; - let message_for_alice = encrypt(bob_store, &alice_address, "hi alice").await?; + let message_for_bob = + encrypt(alice_store, &bob_address, &alice_address, "hi bob").await?; + let message_for_alice = + encrypt(bob_store, &alice_address, &bob_address, "hi alice").await?; assert_eq!( message_for_bob.message_type(), @@ -942,7 +978,8 @@ fn test_basic_simultaneous_initiate() -> TestResult { !is_session_id_equal(alice_store, &alice_address, bob_store, &bob_address).await? ); - let alice_response = encrypt(alice_store, &bob_address, "nice to see you").await?; + let alice_response = + encrypt(alice_store, &bob_address, &alice_address, "nice to see you").await?; assert_eq!( alice_response.message_type(), @@ -967,7 +1004,8 @@ fn test_basic_simultaneous_initiate() -> TestResult { is_session_id_equal(alice_store, &alice_address, bob_store, &bob_address).await? ); - let bob_response = encrypt(bob_store, &alice_address, "you as well").await?; + let bob_response = + encrypt(bob_store, &alice_address, &bob_address, "you as well").await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1055,8 +1093,10 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { ) .await?; - let message_for_bob = encrypt(alice_store, &bob_address, "hi bob").await?; - let message_for_alice = encrypt(bob_store, &alice_address, "hi alice").await?; + let message_for_bob = + encrypt(alice_store, &bob_address, &alice_address, "hi bob").await?; + let message_for_alice = + encrypt(bob_store, &alice_address, &bob_address, "hi alice").await?; assert_eq!( message_for_bob.message_type(), @@ -1094,7 +1134,8 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { expected_session_version ); - let alice_response = encrypt(alice_store, &bob_address, "nice to see you").await?; + let alice_response = + encrypt(alice_store, &bob_address, &alice_address, "nice to see you").await?; assert_eq!(alice_response.message_type(), CiphertextMessageType::PreKey); @@ -1116,7 +1157,8 @@ fn test_simultaneous_initiate_with_lossage() -> TestResult { is_session_id_equal(alice_store, &alice_address, bob_store, &bob_address).await? ); - let bob_response = encrypt(bob_store, &alice_address, "you as well").await?; + let bob_response = + encrypt(bob_store, &alice_address, &bob_address, "you as well").await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1204,8 +1246,10 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { ) .await?; - let message_for_bob = encrypt(alice_store, &bob_address, "hi bob").await?; - let message_for_alice = encrypt(bob_store, &alice_address, "hi alice").await?; + let message_for_bob = + encrypt(alice_store, &bob_address, &alice_address, "hi bob").await?; + let message_for_alice = + encrypt(bob_store, &alice_address, &bob_address, "hi alice").await?; assert_eq!( message_for_bob.message_type(), @@ -1261,7 +1305,8 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { !is_session_id_equal(alice_store, &alice_address, bob_store, &bob_address).await? ); - let alice_response = encrypt(alice_store, &bob_address, "nice to see you").await?; + let alice_response = + encrypt(alice_store, &bob_address, &alice_address, "nice to see you").await?; assert_eq!( alice_response.message_type(), @@ -1272,7 +1317,8 @@ fn test_simultaneous_initiate_lost_message() -> TestResult { !is_session_id_equal(alice_store, &alice_address, bob_store, &bob_address).await? ); - let bob_response = encrypt(bob_store, &alice_address, "you as well").await?; + let bob_response = + encrypt(bob_store, &alice_address, &bob_address, "you as well").await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1358,10 +1404,20 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { ) .await?; - let message_for_bob = - encrypt(&mut alice_store_builder.store, &bob_address, "hi bob").await?; - let message_for_alice = - encrypt(&mut bob_store_builder.store, &alice_address, "hi alice").await?; + let message_for_bob = encrypt( + &mut alice_store_builder.store, + &bob_address, + &alice_address, + "hi bob", + ) + .await?; + let message_for_alice = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "hi alice", + ) + .await?; assert_eq!( message_for_bob.message_type(), @@ -1431,10 +1487,20 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { } for _ in 0..50 { - let message_for_bob = - encrypt(&mut alice_store_builder.store, &bob_address, "hi bob").await?; - let message_for_alice = - encrypt(&mut bob_store_builder.store, &alice_address, "hi alice").await?; + let message_for_bob = encrypt( + &mut alice_store_builder.store, + &bob_address, + &alice_address, + "hi bob", + ) + .await?; + let message_for_alice = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "hi alice", + ) + .await?; assert_eq!( message_for_bob.message_type(), @@ -1506,6 +1572,7 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { let alice_response = encrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, "nice to see you", ) .await?; @@ -1525,8 +1592,13 @@ fn test_simultaneous_initiate_repeated_messages() -> TestResult { .await? ); - let bob_response = - encrypt(&mut bob_store_builder.store, &alice_address, "you as well").await?; + let bob_response = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "you as well", + ) + .await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1606,6 +1678,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let lost_message_for_bob = encrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, "it was so long ago", ) .await?; @@ -1639,10 +1712,20 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { ) .await?; - let message_for_bob = - encrypt(&mut alice_store_builder.store, &bob_address, "hi bob").await?; - let message_for_alice = - encrypt(&mut bob_store_builder.store, &alice_address, "hi alice").await?; + let message_for_bob = encrypt( + &mut alice_store_builder.store, + &bob_address, + &alice_address, + "hi bob", + ) + .await?; + let message_for_alice = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "hi alice", + ) + .await?; assert_eq!( message_for_bob.message_type(), @@ -1712,10 +1795,20 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { } for _ in 0..50 { - let message_for_bob = - encrypt(&mut alice_store_builder.store, &bob_address, "hi bob").await?; - let message_for_alice = - encrypt(&mut bob_store_builder.store, &alice_address, "hi alice").await?; + let message_for_bob = encrypt( + &mut alice_store_builder.store, + &bob_address, + &alice_address, + "hi bob", + ) + .await?; + let message_for_alice = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "hi alice", + ) + .await?; assert_eq!( message_for_bob.message_type(), @@ -1787,6 +1880,7 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { let alice_response = encrypt( &mut alice_store_builder.store, &bob_address, + &alice_address, "nice to see you", ) .await?; @@ -1806,8 +1900,13 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { .await? ); - let bob_response = - encrypt(&mut bob_store_builder.store, &alice_address, "you as well").await?; + let bob_response = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "you as well", + ) + .await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1859,8 +1958,13 @@ fn test_simultaneous_initiate_lost_message_repeated_messages() -> TestResult { .await? ); - let bob_response = - encrypt(&mut bob_store_builder.store, &alice_address, "so it was").await?; + let bob_response = encrypt( + &mut bob_store_builder.store, + &alice_address, + &bob_address, + "so it was", + ) + .await?; assert_eq!(bob_response.message_type(), CiphertextMessageType::Whisper); @@ -1936,7 +2040,13 @@ fn test_zero_is_a_valid_prekey_id() -> TestResult { let original_message = "L'homme est condamné à être libre"; - let outgoing_message = encrypt(&mut alice_store, &bob_address, original_message).await?; + let outgoing_message = encrypt( + &mut alice_store, + &bob_address, + &alice_address, + original_message, + ) + .await?; assert_eq!( outgoing_message.message_type(), @@ -1972,6 +2082,8 @@ fn test_unacknowledged_sessions_eventually_expire() -> TestResult { const WELL_PAST_EXPIRATION: Duration = Duration::from_secs(60 * 60 * 24 * 90); let mut csprng = OsRng.unwrap_err(); + let alice_address = + ProtocolAddress::new("+14151111111".to_owned(), DeviceId::new(1).unwrap()); let bob_address = ProtocolAddress::new("+14151111112".to_owned(), DeviceId::new(1).unwrap()); @@ -2029,6 +2141,7 @@ fn test_unacknowledged_sessions_eventually_expire() -> TestResult { let outgoing_message = message_encrypt( original_message.as_bytes(), &bob_address, + &alice_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::UNIX_EPOCH + Duration::from_secs(1), @@ -2067,6 +2180,7 @@ fn test_unacknowledged_sessions_eventually_expire() -> TestResult { let error = message_encrypt( original_message.as_bytes(), &bob_address, + &alice_address, &mut alice_store.session_store, &mut alice_store.identity_store, SystemTime::UNIX_EPOCH + WELL_PAST_EXPIRATION, @@ -2120,6 +2234,7 @@ fn prekey_message_failed_decryption_does_not_update_stores() -> TestResult { let message = message_encrypt( "from Bob".as_bytes(), &alice_address, + &bob_address, &mut bob_store.session_store, &mut bob_store.identity_store, SystemTime::UNIX_EPOCH, @@ -2224,7 +2339,7 @@ fn prekey_message_failed_decryption_does_not_update_stores_even_when_previously_ .expect("can receive bundle"); // Bob sends a message that decrypts just fine. - let bob_ciphertext = encrypt(&mut bob_store, &alice_address, "from Bob") + let bob_ciphertext = encrypt(&mut bob_store, &alice_address, &bob_address, "from Bob") .await .expect("valid"); _ = decrypt( @@ -2265,6 +2380,7 @@ fn prekey_message_failed_decryption_does_not_update_stores_even_when_previously_ let message = message_encrypt( "from Bob".as_bytes(), &alice_address, + &bob_address, &mut bob_store.session_store, &mut bob_store.identity_store, SystemTime::now(), @@ -2374,7 +2490,7 @@ fn prekey_message_to_archived_session() -> TestResult { .await .expect("can receive bundle"); - let bob_ciphertext = encrypt(&mut bob_store, &alice_address, "from Bob") + let bob_ciphertext = encrypt(&mut bob_store, &alice_address, &bob_address, "from Bob") .await .expect("valid"); assert_eq!(bob_ciphertext.message_type(), CiphertextMessageType::PreKey); @@ -2403,16 +2519,17 @@ fn prekey_message_to_archived_session() -> TestResult { .expect("can receive bundle"); // (This is technically unnecessary, the process_prekey_bundle is sufficient, but it's illustrative.) - let unsent_alice_ciphertext = encrypt(&mut alice_store, &bob_address, "from Alice") - .await - .expect("valid"); + let unsent_alice_ciphertext = + encrypt(&mut alice_store, &bob_address, &alice_address, "from Alice") + .await + .expect("valid"); assert_eq!( unsent_alice_ciphertext.message_type(), CiphertextMessageType::PreKey ); // But before Alice can send the message, she gets a second message from Bob. - let bob_ciphertext_2 = encrypt(&mut bob_store, &alice_address, "from Bob 2") + let bob_ciphertext_2 = encrypt(&mut bob_store, &alice_address, &bob_address, "from Bob 2") .await .expect("valid"); assert_eq!( @@ -2476,7 +2593,13 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec .await?; let alice_plaintext = "This is Alice's message"; - let alice_ciphertext = encrypt(&mut alice_store, &bob_address, alice_plaintext).await?; + let alice_ciphertext = encrypt( + &mut alice_store, + &bob_address, + &alice_address, + alice_plaintext, + ) + .await?; let bob_decrypted = decrypt( &mut bob_store, &alice_address, @@ -2491,7 +2614,8 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec let bob_plaintext = "This is Bob's reply"; - let bob_ciphertext = encrypt(&mut bob_store, &alice_address, bob_plaintext).await?; + let bob_ciphertext = + encrypt(&mut bob_store, &alice_address, &bob_address, bob_plaintext).await?; let alice_decrypted = decrypt( &mut alice_store, &bob_address, @@ -2511,7 +2635,7 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec for i in 0..ALICE_MESSAGE_COUNT { let ptext = format!("смерть за смерть {i}"); - let ctext = encrypt(&mut alice_store, &bob_address, &ptext).await?; + let ctext = encrypt(&mut alice_store, &bob_address, &alice_address, &ptext).await?; alice_messages.push((ptext, ctext)); } @@ -2537,7 +2661,7 @@ fn run_session_interaction(alice_session: SessionRecord, bob_session: SessionRec for i in 0..BOB_MESSAGE_COUNT { let ptext = format!("Relax in the safety of your own delusions. {i}"); - let ctext = encrypt(&mut bob_store, &alice_address, &ptext).await?; + let ctext = encrypt(&mut bob_store, &alice_address, &bob_address, &ptext).await?; bob_messages.push((ptext, ctext)); } @@ -2599,7 +2723,7 @@ async fn run_interaction( ) -> TestResult { let alice_ptext = "It's rabbit season"; - let alice_message = encrypt(alice_store, bob_address, alice_ptext).await?; + let alice_message = encrypt(alice_store, bob_address, alice_address, alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8(decrypt(bob_store, alice_address, bob_address, &alice_message).await?) @@ -2609,7 +2733,7 @@ async fn run_interaction( let bob_ptext = "It's duck season"; - let bob_message = encrypt(bob_store, alice_address, bob_ptext).await?; + let bob_message = encrypt(bob_store, alice_address, bob_address, bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8(decrypt(alice_store, bob_address, alice_address, &bob_message).await?) @@ -2619,7 +2743,7 @@ async fn run_interaction( for i in 0..10 { let alice_ptext = format!("A->B message {i}"); - let alice_message = encrypt(alice_store, bob_address, &alice_ptext).await?; + let alice_message = encrypt(alice_store, bob_address, alice_address, &alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8( @@ -2632,7 +2756,7 @@ async fn run_interaction( for i in 0..10 { let bob_ptext = format!("B->A message {i}"); - let bob_message = encrypt(bob_store, alice_address, &bob_ptext).await?; + let bob_message = encrypt(bob_store, alice_address, bob_address, &bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8( @@ -2647,13 +2771,13 @@ async fn run_interaction( for i in 0..10 { let alice_ptext = format!("A->B OOO message {i}"); - let alice_message = encrypt(alice_store, bob_address, &alice_ptext).await?; + let alice_message = encrypt(alice_store, bob_address, alice_address, &alice_ptext).await?; alice_ooo_messages.push((alice_ptext, alice_message)); } for i in 0..10 { let alice_ptext = format!("A->B post-OOO message {i}"); - let alice_message = encrypt(alice_store, bob_address, &alice_ptext).await?; + let alice_message = encrypt(alice_store, bob_address, alice_address, &alice_ptext).await?; assert_eq!(alice_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8( @@ -2666,7 +2790,7 @@ async fn run_interaction( for i in 0..10 { let bob_ptext = format!("B->A message post-OOO {i}"); - let bob_message = encrypt(bob_store, alice_address, &bob_ptext).await?; + let bob_message = encrypt(bob_store, alice_address, bob_address, &bob_ptext).await?; assert_eq!(bob_message.message_type(), CiphertextMessageType::Whisper); assert_eq!( String::from_utf8( @@ -2739,11 +2863,13 @@ fn test_signedprekey_not_saved() -> TestResult { let original_message = "L'homme est condamné à être libre"; // We encrypt a first message - let outgoing_message = encrypt(alice_store, &bob_address, original_message).await?; + let outgoing_message = + encrypt(alice_store, &bob_address, &alice_address, original_message).await?; // We encrypt a second message let original_message2 = "L'homme est condamné à nouveau à être libre"; - let outgoing_message2 = encrypt(alice_store, &bob_address, original_message2).await?; + let outgoing_message2 = + encrypt(alice_store, &bob_address, &alice_address, original_message2).await?; assert_eq!( outgoing_message.message_type(), @@ -2995,14 +3121,14 @@ fn test_longer_sessions() -> TestResult { log::debug!("Send message to Alice"); to_alice.push_back(( false, - encrypt(bob_store, &alice_address, "wheee1").await?, + encrypt(bob_store, &alice_address, &bob_address, "wheee1").await?, )); } LongerSessionActions::BobSend => { log::debug!("Send message to Bob"); to_bob.push_back(( false, - encrypt(alice_store, &bob_address, "wheee2").await?, + encrypt(alice_store, &bob_address, &alice_address, "wheee2").await?, )); } LongerSessionActions::AliceRecv => match to_alice.pop_front() { @@ -3099,7 +3225,13 @@ fn test_duplicate_message_error_returned() -> TestResult { ) .await?; - let msg = encrypt(alice_store, &bob_address, "this_will_be_a_dup").await?; + let msg = encrypt( + alice_store, + &bob_address, + &alice_address, + "this_will_be_a_dup", + ) + .await?; decrypt(bob_store, &alice_address, &bob_address, &msg).await?; let err = decrypt(bob_store, &alice_address, &bob_address, &msg) .await @@ -3146,15 +3278,15 @@ fn test_pqr_state_and_message_contents_nonempty() -> TestResult { ) .await?; - let msg = encrypt(alice_store, &bob_address, "msg1").await?; + let msg = encrypt(alice_store, &bob_address, &alice_address, "msg1").await?; assert_matches!(&msg, CiphertextMessage::PreKeySignalMessage(m) if !m.message().pq_ratchet().is_empty()); decrypt(bob_store, &alice_address, &bob_address, &msg).await?; - let msg = encrypt(bob_store, &alice_address, "msg2").await?; + let msg = encrypt(bob_store, &alice_address, &bob_address, "msg2").await?; assert_matches!(&msg, CiphertextMessage::SignalMessage(m) if !m.pq_ratchet().is_empty()); decrypt(alice_store, &bob_address, &alice_address, &msg).await?; - let msg = encrypt(alice_store, &bob_address, "msg3").await?; + let msg = encrypt(alice_store, &bob_address, &alice_address, "msg3").await?; assert_matches!(&msg, CiphertextMessage::SignalMessage(m) if !m.pq_ratchet().is_empty()); assert!(!alice_store @@ -3211,9 +3343,10 @@ fn x3dh_prekey_rejected_as_invalid_message_specifically() { .await .expect("valid"); - let pre_key_message = support::encrypt(&mut alice_store, &bob_address, "bad") - .await - .expect("valid"); + let pre_key_message = + support::encrypt(&mut alice_store, &bob_address, &alice_address, "bad") + .await + .expect("valid"); let mut bob_one_off_store = bob_store_builder.store.clone(); _ = support::decrypt( @@ -3289,9 +3422,10 @@ fn x3dh_established_session_is_or_is_not_usable() { .await .expect("valid"); - let pre_key_message = support::encrypt(&mut alice_store, &bob_address, "bad") - .await - .expect("valid"); + let pre_key_message = + support::encrypt(&mut alice_store, &bob_address, &alice_address, "bad") + .await + .expect("valid"); let bob_store = &mut bob_store_builder.store; _ = support::decrypt(bob_store, &alice_address, &bob_address, &pre_key_message) @@ -3391,9 +3525,10 @@ fn prekey_message_sent_from_different_user_is_rejected() { .await .expect("valid"); - let pre_key_message = support::encrypt(&mut alice_store, &bob_address, "bad") - .await - .expect("valid"); + let pre_key_message = + support::encrypt(&mut alice_store, &bob_address, &alice_address, "bad") + .await + .expect("valid"); let bob_store = &mut bob_store_builder.store; _ = support::decrypt(bob_store, &alice_address, &bob_address, &pre_key_message) @@ -3463,9 +3598,10 @@ fn prekey_message_rejects_wrong_local_recipient_address() { .await .expect("valid"); - let pre_key_message = support::encrypt(&mut alice_store, &bob_address, "hi bob") - .await - .expect("valid"); + let pre_key_message = + support::encrypt(&mut alice_store, &bob_address, &alice_address, "hi bob") + .await + .expect("valid"); let err = support::decrypt( &mut bob_store_builder.store, diff --git a/rust/protocol/tests/support/mod.rs b/rust/protocol/tests/support/mod.rs index 4bab798fc..864793c27 100644 --- a/rust/protocol/tests/support/mod.rs +++ b/rust/protocol/tests/support/mod.rs @@ -31,12 +31,14 @@ pub fn test_in_memory_protocol_store() -> Result Result { let mut csprng = OsRng.unwrap_err(); message_encrypt( msg.as_bytes(), remote_address, + local_address, &mut store.session_store, &mut store.identity_store, SystemTime::now(), diff --git a/swift/Sources/LibSignalClient/Protocol.swift b/swift/Sources/LibSignalClient/Protocol.swift index a513b3186..398da324e 100644 --- a/swift/Sources/LibSignalClient/Protocol.swift +++ b/swift/Sources/LibSignalClient/Protocol.swift @@ -9,12 +9,16 @@ import SignalFfi public func signalEncrypt( message: Bytes, for address: ProtocolAddress, + localAddress: ProtocolAddress, sessionStore: SessionStore, identityStore: IdentityKeyStore, now: Date = Date(), context: StoreContext ) throws -> CiphertextMessage { - return try withAllBorrowed(address, .bytes(message)) { addressHandle, messageBuffer in + return try withAllBorrowed(address, localAddress, .bytes(message)) { + addressHandle, + localAddressHandle, + messageBuffer in try withSessionStore(sessionStore, context) { ffiSessionStore in try withIdentityKeyStore(identityStore, context) { ffiIdentityStore in try invokeFnReturningNativeHandle { @@ -22,6 +26,7 @@ public func signalEncrypt( $0, messageBuffer, addressHandle.const(), + localAddressHandle.const(), ffiSessionStore, ffiIdentityStore, UInt64(now.timeIntervalSince1970 * 1000) diff --git a/swift/Sources/SignalFfi/signal_ffi.h b/swift/Sources/SignalFfi/signal_ffi.h index ee071f242..0b2234f9d 100644 --- a/swift/Sources/SignalFfi/signal_ffi.h +++ b/swift/Sources/SignalFfi/signal_ffi.h @@ -1932,7 +1932,7 @@ SignalFfiError *signal_device_transfer_generate_private_key(SignalOwnedBuffer *o SignalFfiError *signal_device_transfer_generate_private_key_with_format(SignalOwnedBuffer *out, uint8_t key_format); -SignalFfiError *signal_encrypt_message(SignalMutPointerCiphertextMessage *out, SignalBorrowedBuffer ptext, SignalConstPointerProtocolAddress protocol_address, SignalConstPointerFfiSessionStoreStruct session_store, SignalConstPointerFfiIdentityKeyStoreStruct identity_key_store, uint64_t now); +SignalFfiError *signal_encrypt_message(SignalMutPointerCiphertextMessage *out, SignalBorrowedBuffer ptext, SignalConstPointerProtocolAddress protocol_address, SignalConstPointerProtocolAddress local_address, SignalConstPointerFfiSessionStoreStruct session_store, SignalConstPointerFfiIdentityKeyStoreStruct identity_key_store, uint64_t now); void signal_error_free(SignalFfiError *err); diff --git a/swift/Tests/LibSignalClientTests/SessionTests.swift b/swift/Tests/LibSignalClientTests/SessionTests.swift index 0f192738e..f856abcc5 100644 --- a/swift/Tests/LibSignalClientTests/SessionTests.swift +++ b/swift/Tests/LibSignalClientTests/SessionTests.swift @@ -31,6 +31,7 @@ class SessionTests: TestCaseBase { let ctext_a = try! signalEncrypt( message: ptext_a, for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, context: NullContext() @@ -60,6 +61,7 @@ class SessionTests: TestCaseBase { let ctext2_b = try! signalEncrypt( message: ptext2_b, for: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, context: NullContext() @@ -99,6 +101,7 @@ class SessionTests: TestCaseBase { let ctext_a = try! signalEncrypt( message: ptext_a, for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, context: NullContext() @@ -132,6 +135,7 @@ class SessionTests: TestCaseBase { func testExpiresUnacknowledgedSessions() { let bob_address = try! ProtocolAddress(name: "+14151111112", deviceId: 1) + let alice_address = try! ProtocolAddress(name: "+14151111111", deviceId: 1) let alice_store = InMemorySignalProtocolStore() let bob_store = InMemorySignalProtocolStore() @@ -190,6 +194,7 @@ class SessionTests: TestCaseBase { let ctext_a = try! signalEncrypt( message: ptext_a, for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, now: Date(timeIntervalSinceReferenceDate: 0), @@ -206,6 +211,7 @@ class SessionTests: TestCaseBase { try signalEncrypt( message: ptext_a, for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, now: Date(timeIntervalSinceReferenceDate: 60 * 60 * 24 * 90), @@ -290,6 +296,7 @@ class SessionTests: TestCaseBase { let ctext_a = try! signalEncrypt( message: ptext_a, for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, context: NullContext() @@ -372,6 +379,10 @@ class SessionTests: TestCaseBase { let ciphertextMessage = try signalEncrypt( message: message, for: address, + localAddress: try ProtocolAddress( + name: senderCert.sender.uuidString, + deviceId: UInt32(senderCert.sender.deviceId) + ), sessionStore: sessionStore, identityStore: identityStore, context: context @@ -437,6 +448,7 @@ class SessionTests: TestCaseBase { let innerMessage = try signalEncrypt( message: [], for: bob_address, + localAddress: alice_address, sessionStore: alice_store, identityStore: alice_store, context: NullContext() @@ -767,6 +779,7 @@ class SessionTests: TestCaseBase { let bob_first_message = try signalEncrypt( message: Array("swim camp".utf8), for: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, context: NullContext() @@ -786,6 +799,7 @@ class SessionTests: TestCaseBase { let bob_message = try signalEncrypt( message: Array("space camp".utf8), for: alice_address, + localAddress: bob_address, sessionStore: bob_store, identityStore: bob_store, context: NullContext()