bridge: Remove SignalFfiError enum

Now there's a trait, FfiError, which handles conversion to a string
and numeric code, and a helper struct SignalFfiError that mostly just
wraps `Box<dyn FfiError>`. This makes it easier to add new errors --
they only need to be added in two places (a trait impl and possibly
new error codes) instead of three.
This commit is contained in:
Jordan Rose
2024-06-03 19:50:11 -07:00
parent 4e2a7de574
commit 86c07ee86a
8 changed files with 610 additions and 661 deletions

View File

@@ -14,9 +14,6 @@ use std::ffi::{c_char, c_uchar, c_uint, CString};
use std::panic::AssertUnwindSafe;
pub mod logging;
mod util;
use crate::util::*;
#[no_mangle]
pub unsafe extern "C" fn signal_print_ptr(p: *const std::ffi::c_void) {
@@ -69,11 +66,8 @@ pub unsafe extern "C" fn signal_error_get_message(
out: *mut *const c_char,
) -> *mut SignalFfiError {
let result = (|| {
if err.is_null() {
return Err(SignalFfiError::NullPointer);
}
let msg = format!("{}", *err);
write_result_to(out, msg)
let err = err.as_ref().ok_or(NullPointerError)?;
write_result_to(out, err.to_string())
})();
match result {
@@ -89,18 +83,17 @@ pub unsafe extern "C" fn signal_error_get_address(
) -> *mut SignalFfiError {
let err = AssertUnwindSafe(err);
run_ffi_safe(|| {
let err = err.as_ref().ok_or(SignalFfiError::NullPointer)?;
match err {
SignalFfiError::Signal(SignalProtocolError::InvalidRegistrationId(addr, _value)) => {
let err = err.as_ref().ok_or(NullPointerError)?;
match err.downcast_ref::<SignalProtocolError>() {
Some(SignalProtocolError::InvalidRegistrationId(addr, _value)) => {
write_result_to(out, addr.clone())?;
}
_ => {
return Err(SignalFfiError::Signal(
SignalProtocolError::InvalidArgument(format!(
"cannot get address from error ({})",
err
)),
));
return Err(SignalProtocolError::InvalidArgument(format!(
"cannot get address from error ({})",
err
))
.into());
}
}
Ok(())
@@ -114,20 +107,17 @@ pub unsafe extern "C" fn signal_error_get_uuid(
) -> *mut SignalFfiError {
let err = AssertUnwindSafe(err);
run_ffi_safe(|| {
let err = err.as_ref().ok_or(SignalFfiError::NullPointer)?;
match err {
SignalFfiError::Signal(SignalProtocolError::InvalidSenderKeySession {
distribution_id,
}) => {
let err = err.as_ref().ok_or(NullPointerError)?;
match err.downcast_ref::<SignalProtocolError>() {
Some(SignalProtocolError::InvalidSenderKeySession { distribution_id }) => {
write_result_to(out, *distribution_id.as_bytes())?;
}
_ => {
return Err(SignalFfiError::Signal(
SignalProtocolError::InvalidArgument(format!(
"cannot get address from error ({})",
err
)),
));
return Err(SignalProtocolError::InvalidArgument(format!(
"cannot get UUID from error ({})",
err
))
.into());
}
}
Ok(())
@@ -137,10 +127,7 @@ pub unsafe extern "C" fn signal_error_get_uuid(
#[no_mangle]
pub unsafe extern "C" fn signal_error_get_type(err: *const SignalFfiError) -> u32 {
match err.as_ref() {
Some(err) => {
let code: SignalErrorCode = err.into();
code as u32
}
Some(err) => err.code() as u32,
None => 0,
}
}
@@ -152,16 +139,16 @@ pub unsafe extern "C" fn signal_error_get_retry_after_seconds(
) -> *mut SignalFfiError {
let err = AssertUnwindSafe(err);
run_ffi_safe(|| {
let err = err.as_ref().ok_or(SignalFfiError::NullPointer)?;
match err {
SignalFfiError::RateLimited {
let err = err.as_ref().ok_or(NullPointerError)?;
match err.downcast_ref::<libsignal_net::cdsi::LookupError>() {
Some(libsignal_net::cdsi::LookupError::RateLimited {
retry_after_seconds,
} => write_result_to(out, *retry_after_seconds),
err => Err(SignalFfiError::Signal(
SignalProtocolError::InvalidArgument(format!(
"cannot get retry_after_seconds from error ({err})"
)),
)),
}) => write_result_to(out, *retry_after_seconds),
_ => Err(SignalProtocolError::InvalidArgument(format!(
"cannot get retry_after_seconds from error ({})",
err
))
.into()),
}
})
}
@@ -173,16 +160,16 @@ pub unsafe extern "C" fn signal_error_get_tries_remaining(
) -> *mut SignalFfiError {
let err = AssertUnwindSafe(err);
run_ffi_safe(|| {
let err = err.as_ref().ok_or(SignalFfiError::NullPointer)?;
match err {
SignalFfiError::Svr(libsignal_net::svr3::Error::RestoreFailed(tries_remaining)) => {
let err = err.as_ref().ok_or(NullPointerError)?;
match err.downcast_ref::<libsignal_net::svr3::Error>() {
Some(libsignal_net::svr3::Error::RestoreFailed(tries_remaining)) => {
write_result_to(out, *tries_remaining)
}
err => Err(SignalFfiError::Signal(
SignalProtocolError::InvalidArgument(format!(
"cannot get tries_remaining from error ({err})"
)),
)),
_ => Err(SignalProtocolError::InvalidArgument(format!(
"cannot get tries_remaining from error ({})",
err
))
.into()),
}
})
}
@@ -230,15 +217,13 @@ pub unsafe extern "C" fn signal_sealed_session_cipher_decrypt(
let mut kyber_pre_key_store = InMemKyberPreKeyStore::new();
let ctext = ctext.as_slice()?;
let trust_root = native_handle_cast::<PublicKey>(trust_root)?;
let mut identity_store = identity_store.as_ref().ok_or(SignalFfiError::NullPointer)?;
let mut session_store = session_store.as_ref().ok_or(SignalFfiError::NullPointer)?;
let mut prekey_store = prekey_store.as_ref().ok_or(SignalFfiError::NullPointer)?;
let signed_prekey_store = signed_prekey_store
.as_ref()
.ok_or(SignalFfiError::NullPointer)?;
let mut identity_store = identity_store.as_ref().ok_or(NullPointerError)?;
let mut session_store = session_store.as_ref().ok_or(NullPointerError)?;
let mut prekey_store = prekey_store.as_ref().ok_or(NullPointerError)?;
let signed_prekey_store = signed_prekey_store.as_ref().ok_or(NullPointerError)?;
let local_e164 = Option::convert_from(local_e164)?;
let local_uuid = Option::convert_from(local_uuid)?.ok_or(SignalFfiError::NullPointer)?;
let local_uuid = Option::convert_from(local_uuid)?.ok_or(NullPointerError)?;
let decrypted = sealed_sender_decrypt(
ctext,