Invoke native methods instead of manipulating protobufs within Java

This commit is contained in:
Jack Lloyd
2020-12-08 18:06:28 -05:00
parent a32efa2d24
commit f8182af008
6 changed files with 189 additions and 122 deletions

View File

@@ -203,8 +203,28 @@ public final class Native {
public static native byte[] SessionCipher_DecryptSignalMessage(long message, long protocolAddress, SessionStore sessionStore, IdentityKeyStore identityKeyStore);
public static native CiphertextMessage SessionCipher_EncryptMessage(byte[] message, long protocolAddress, SessionStore sessionStore, IdentityKeyStore identityKeyStore);
public static native void SessionRecord_ArchiveCurrentState(long handle);
public static native long SessionRecord_Deserialize(byte[] data);
public static native void SessionRecord_Destroy(long handle);
public static native long SessionRecord_FromSessionState(long sessionState);
public static native long SessionRecord_GetSessionState(long sessionRecord);
public static native long SessionRecord_NewFresh();
public static native byte[] SessionRecord_Serialize(long handle);
public static native long SessionState_Deserialize(byte[] data);
public static native void SessionState_Destroy(long handle);
public static native byte[] SessionState_GetAliceBaseKey(long handle);
public static native byte[] SessionState_GetLocalIdentityKeyPublic(long handle);
public static native int SessionState_GetLocalRegistrationId(long handle);
public static native byte[] SessionState_GetReceiverChainKeyValue(long sessionRecord, long key);
public static native byte[] SessionState_GetRemoteIdentityKeyPublic(long handle);
public static native int SessionState_GetRemoteRegistrationId(long handle);
public static native byte[] SessionState_GetSenderChainKeyValue(long handle);
public static native int SessionState_GetSessionVersion(long handle);
public static native boolean SessionState_HasSenderChain(long handle);
public static native byte[] SessionState_InitializeAliceSession(long identityKeyPrivate, long identityKeyPublic, long basePrivate, long basePublic, long theirIdentityKey, long theirSignedPrekey, long theirRatchetKey);
public static native byte[] SessionState_InitializeBobSession(long identityKeyPrivate, long identityKeyPublic, long signedPrekeyPrivate, long signedPrekeyPublic, long ephPrivate, long ephPublic, long theirIdentityKey, long theirBaseKey);
public static native byte[] SessionState_Serialized(long handle);
public static native long SignalMessage_Deserialize(byte[] data);
public static native void SignalMessage_Destroy(long handle);

View File

@@ -5,13 +5,8 @@
*/
package org.whispersystems.libsignal.state;
import org.signal.client.internal.Native;
import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import static org.whispersystems.libsignal.state.StorageProtos.RecordStructure;
import static org.whispersystems.libsignal.state.StorageProtos.SessionStructure;
/**
* A SessionRecord encapsulates the state of an ongoing session.
@@ -20,33 +15,27 @@ import static org.whispersystems.libsignal.state.StorageProtos.SessionStructure;
*/
public class SessionRecord {
private static final int ARCHIVED_STATES_MAX_LENGTH = 40;
long handle;
private SessionState sessionState = new SessionState();
private LinkedList<SessionState> previousStates = new LinkedList<>();
private boolean fresh = false;
@Override
protected void finalize() {
Native.SessionRecord_Destroy(this.handle);
}
public SessionRecord() {
this.fresh = true;
this.handle = Native.SessionRecord_NewFresh();
}
public SessionRecord(SessionState sessionState) {
this.sessionState = sessionState;
this.fresh = false;
this.handle = Native.SessionRecord_FromSessionState(sessionState.nativeHandle());
}
public SessionRecord(byte[] serialized) throws IOException {
RecordStructure record = RecordStructure.parseFrom(serialized);
this.sessionState = new SessionState(record.getCurrentSession());
this.fresh = false;
for (SessionStructure previousStructure : record.getPreviousSessionsList()) {
previousStates.add(new SessionState(previousStructure));
}
this.handle = Native.SessionRecord_Deserialize(serialized);
}
public SessionState getSessionState() {
return sessionState;
return new SessionState(Native.SessionRecord_GetSessionState(this.handle));
}
/**
@@ -55,34 +44,14 @@ public class SessionRecord {
* with a fresh reset instance.
*/
public void archiveCurrentState() {
promoteState(new SessionState());
}
private void promoteState(SessionState promotedState) {
this.previousStates.addFirst(sessionState);
this.sessionState = promotedState;
if (previousStates.size() > ARCHIVED_STATES_MAX_LENGTH) {
previousStates.removeLast();
}
Native.SessionRecord_ArchiveCurrentState(this.handle);
}
/**
* @return a serialized version of the current SessionRecord.
*/
public byte[] serialize() {
List<SessionStructure> previousStructures = new LinkedList<>();
for (SessionState previousState : previousStates) {
previousStructures.add(previousState.getStructure());
}
RecordStructure record = RecordStructure.newBuilder()
.setCurrentSession(sessionState.getStructure())
.addAllPreviousSessions(previousStructures)
.build();
return record.toByteArray();
return Native.SessionRecord_Serialize(this.handle);
}
}

View File

@@ -7,30 +7,21 @@
package org.whispersystems.libsignal.state;
import org.signal.client.internal.Native;
import org.whispersystems.libsignal.IdentityKey;
import org.whispersystems.libsignal.IdentityKeyPair;
import org.whispersystems.libsignal.InvalidKeyException;
import org.whispersystems.libsignal.ecc.Curve;
import org.whispersystems.libsignal.ecc.ECKeyPair;
import org.whispersystems.libsignal.ecc.ECPrivateKey;
import org.whispersystems.libsignal.ecc.ECPublicKey;
import org.whispersystems.libsignal.logging.Log;
import org.whispersystems.libsignal.state.StorageProtos.SessionStructure.Chain;
import org.whispersystems.libsignal.state.StorageProtos.SessionStructure.PendingKeyExchange;
import org.whispersystems.libsignal.state.StorageProtos.SessionStructure.PendingPreKey;
import org.whispersystems.libsignal.util.Pair;
import org.whispersystems.libsignal.util.guava.Optional;
import org.whispersystems.libsignal.InvalidKeyException;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import static org.whispersystems.libsignal.state.StorageProtos.SessionStructure;
public class SessionState {
private SessionStructure sessionStructure;
private long handle;
@Override
protected void finalize() {
Native.SessionState_Destroy(this.handle);
}
static public SessionState initializeAliceSession(IdentityKeyPair identityKey,
ECKeyPair baseKey,
@@ -70,109 +61,80 @@ public class SessionState {
}
public SessionState(byte[] serialized) throws IOException {
this.sessionStructure = SessionStructure.parseFrom(serialized);
}
private static final int MAX_MESSAGE_KEYS = 2000;
public SessionState() {
this.sessionStructure = SessionStructure.newBuilder().build();
this.handle = Native.SessionState_Deserialize(serialized);
}
public SessionState(SessionStructure sessionStructure) {
this.sessionStructure = sessionStructure;
this.handle = Native.SessionState_Deserialize(sessionStructure.toByteArray());
}
SessionState(long handle) {
this.handle = handle;
}
// Remove this:
SessionState(SessionState copy) {
this.sessionStructure = copy.sessionStructure.toBuilder().build();
}
SessionStructure getStructure() {
return sessionStructure;
this.handle = copy.handle;
}
public byte[] getAliceBaseKey() {
return this.sessionStructure.getAliceBaseKey().toByteArray();
return Native.SessionState_GetAliceBaseKey(this.handle);
}
public int getSessionVersion() {
int sessionVersion = this.sessionStructure.getSessionVersion();
if (sessionVersion == 0) return 2;
else return sessionVersion;
return Native.SessionState_GetSessionVersion(this.handle);
}
public IdentityKey getRemoteIdentityKey() {
try {
if (!this.sessionStructure.hasRemoteIdentityPublic()) {
return null;
}
byte[] keyBytes = Native.SessionState_GetRemoteIdentityKeyPublic(this.handle);
return new IdentityKey(this.sessionStructure.getRemoteIdentityPublic().toByteArray(), 0);
} catch (InvalidKeyException e) {
Log.w("SessionRecordV2", e);
if (keyBytes == null){
return null;
}
try {
return new IdentityKey(keyBytes);
}
catch(InvalidKeyException e) {
throw new AssertionError(e);
}
}
public IdentityKey getLocalIdentityKey() {
byte[] keyBytes = Native.SessionState_GetLocalIdentityKeyPublic(this.handle);
try {
return new IdentityKey(this.sessionStructure.getLocalIdentityPublic().toByteArray(), 0);
} catch (InvalidKeyException e) {
return new IdentityKey(keyBytes);
}
catch(InvalidKeyException e) {
throw new AssertionError(e);
}
}
public boolean hasSenderChain() {
return sessionStructure.hasSenderChain();
return Native.SessionState_HasSenderChain(this.handle);
}
private Pair<Chain,Integer> getReceiverChain(ECPublicKey senderEphemeral) {
List<Chain> receiverChains = sessionStructure.getReceiverChainsList();
int index = 0;
for (Chain receiverChain : receiverChains) {
try {
ECPublicKey chainSenderRatchetKey = Curve.decodePoint(receiverChain.getSenderRatchetKey().toByteArray(), 0);
if (chainSenderRatchetKey.equals(senderEphemeral)) {
return new Pair<>(receiverChain,index);
}
} catch (InvalidKeyException e) {
Log.w("SessionRecordV2", e);
}
index++;
}
return null;
}
public byte[] getReceiverChainKeyValue(ECPublicKey senderEphemeral) {
Pair<Chain,Integer> receiverChainAndIndex = getReceiverChain(senderEphemeral);
Chain receiverChain = receiverChainAndIndex.first();
if (receiverChain == null) {
return null;
} else {
return receiverChain.getChainKey().getKey().toByteArray();
}
public byte[] getReceiverChainKeyValue(ECPublicKey senderEphemeral) {
return Native.SessionState_GetReceiverChainKeyValue(this.handle, senderEphemeral.nativeHandle());
}
public byte[] getSenderChainKeyValue() {
Chain.ChainKey chainKeyStructure = sessionStructure.getSenderChain().getChainKey();
return chainKeyStructure.getKey().toByteArray();
return Native.SessionState_GetSenderChainKeyValue(this.handle);
}
public int getRemoteRegistrationId() {
return this.sessionStructure.getRemoteRegistrationId();
return Native.SessionState_GetRemoteRegistrationId(this.handle);
}
public int getLocalRegistrationId() {
return this.sessionStructure.getLocalRegistrationId();
return Native.SessionState_GetLocalRegistrationId(this.handle);
}
public byte[] serialize() {
return sessionStructure.toByteArray();
return Native.SessionState_Serialized(this.handle);
}
long nativeHandle() {
return this.handle;
}
}