AK+LibCore: Allow and use Promise<void> and awaiting ErrorOr<void>

This commit is contained in:
Hendiadyoin1
2025-03-06 19:45:04 +01:00
committed by Nico Weber
parent 3afe0ecb8e
commit 3734aef3c1
7 changed files with 167 additions and 25 deletions

View File

@@ -242,6 +242,12 @@ struct TryAwaiter {
}
}
void await_resume()
requires(IsSame<T, ErrorOr<void>>)
{
(void)m_expression->release_value();
}
decltype(auto) await_resume()
{
return m_expression->release_value();
@@ -263,7 +269,7 @@ auto declval_coro_result(T&&) -> T;
// GCC cannot handle CO_TRY(...CO_TRY(...)...), this hack ensures that it always has the right type information available.
// FIXME: Remove this once GCC can correctly infer the result type of `co_await TryAwaiter { ... }`.
# define CO_TRY(expression) static_cast<decltype(AK::Detail::declval_coro_result(expression).release_value())>(co_await ::AK::Detail::TryAwaiter { (expression) })
# define CO_TRY(expression) static_cast<AddRvalueReference<typename RemoveReference<decltype(expression)>::ResultType>>(co_await ::AK::Detail::TryAwaiter { (expression) })
# endif
#elifndef AK_COROUTINE_STATEMENT_EXPRS_BROKEN
# define CO_TRY(expression) \

View File

@@ -156,4 +156,145 @@ private:
Optional<ErrorOr<Result, ErrorType>> m_result_or_rejection;
};
template<typename TError>
class Promise<void, TError> : public EventReceiver {
C_OBJECT(Promise);
public:
using ErrorType = TError;
Function<ErrorOr<void>()> on_resolution;
Function<void(ErrorType&)> on_rejection;
void resolve(void)
{
m_result_or_rejection = Empty {};
if (on_resolution) {
auto handler_result = on_resolution();
possibly_handle_rejection(handler_result);
}
}
void reject(ErrorType&& error)
{
m_result_or_rejection = move(error);
possibly_handle_rejection(*m_result_or_rejection);
}
bool is_rejected()
{
return m_result_or_rejection.has_value() && m_result_or_rejection->is_error();
}
bool is_resolved() const
{
return m_result_or_rejection.has_value() && !m_result_or_rejection->is_error();
}
ErrorOr<void, ErrorType> await()
{
while (!m_result_or_rejection.has_value())
Core::EventLoop::current().pump();
return m_result_or_rejection.release_value();
}
// Converts a Promise<A> to a Promise<B> using a function func: A -> B
template<typename T>
NonnullRefPtr<Promise<T>> map(Function<T(void)> func)
{
NonnullRefPtr<Promise<T>> new_promise = Promise<T>::construct();
if (is_resolved())
new_promise->resolve(func(m_result_or_rejection->value()));
if (is_rejected())
new_promise->reject(m_result_or_rejection->release_error());
on_resolution = [new_promise, func = move(func)](void) -> ErrorOr<void> {
new_promise->resolve(func());
return {};
};
on_rejection = [new_promise](ErrorType& error) {
new_promise->reject(move(error));
};
return new_promise;
}
template<typename T>
NonnullRefPtr<Promise<T>> map(Function<ErrorOr<T>(void)> func)
{
NonnullRefPtr<Promise<T>> new_promise = Promise<T>::construct();
if (is_resolved()) {
auto result = func(m_result_or_rejection->value());
if (result.is_error())
new_promise->reject(result.release_error());
else
new_promise->resolve(result.release_value());
}
if (is_rejected())
new_promise->reject(m_result_or_rejection->release_error());
on_resolution = [new_promise, func = move(func)](void) -> ErrorOr<void> {
auto new_result = func();
if (new_result.is_error())
new_promise->reject(new_result.release_error());
else
new_promise->resolve(new_result.release_value());
return {};
};
on_rejection = [new_promise](ErrorType& error) {
new_promise->reject(move(error));
};
return new_promise;
}
template<CallableAs<void> F>
Promise& when_resolved(F handler)
{
return when_resolved([handler = move(handler)](void) -> ErrorOr<void> {
handler();
return {};
});
}
template<CallableAs<ErrorOr<void>> F>
Promise& when_resolved(F handler)
{
on_resolution = move(handler);
if (is_resolved()) {
auto handler_result = on_resolution();
possibly_handle_rejection(handler_result);
}
return *this;
}
template<CallableAs<void, ErrorType&> F>
Promise& when_rejected(F handler)
{
on_rejection = move(handler);
if (is_rejected())
on_rejection(m_result_or_rejection->error());
return *this;
}
private:
template<typename T>
void possibly_handle_rejection(ErrorOr<T>& result)
{
if (result.is_error() && on_rejection)
on_rejection(result.error());
}
Promise() = default;
Promise(EventReceiver* parent)
: EventReceiver(parent)
{
}
Optional<ErrorOr<void, ErrorType>> m_result_or_rejection;
};
}

View File

@@ -13,7 +13,7 @@ Client::Client(StringView host, u16 port, NonnullOwnPtr<Core::Socket> socket)
: m_host(host)
, m_port(port)
, m_socket(move(socket))
, m_connect_pending(Promise<Empty>::construct())
, m_connect_pending(Promise<void>::construct())
{
setup_callbacks();
}
@@ -84,7 +84,7 @@ ErrorOr<void> Client::on_ready_to_receive()
// Once we get server hello we can start sending.
if (m_connect_pending) {
m_connect_pending->resolve({});
m_connect_pending->resolve();
m_connect_pending.clear();
m_buffer.clear();
return {};

View File

@@ -25,7 +25,7 @@ public:
Client(Client&&);
RefPtr<Promise<Empty>> connection_promise()
RefPtr<Promise<void>> connection_promise()
{
return m_connect_pending;
}
@@ -72,7 +72,7 @@ private:
u16 m_port;
NonnullOwnPtr<Core::Socket> m_socket;
RefPtr<Promise<Empty>> m_connect_pending {};
RefPtr<Promise<void>> m_connect_pending;
int m_current_command = 1;

View File

@@ -56,7 +56,8 @@ struct PromiseAwaiter {
bool await_ready() const { return promise->is_resolved(); }
void await_suspend(std::coroutine_handle<> awaiter)
{
promise->when_resolved([awaiter](auto&) {
// Note: Argument pack as to allow `Promise<void>` which does not pass any arguments to the callback.
promise->when_resolved([awaiter](auto&...) {
Core::deferred_invoke([awaiter] { awaiter.resume(); });
});
promise->when_rejected([awaiter](auto&) {
@@ -76,14 +77,12 @@ struct PromiseAwaiter {
Coroutine<ErrorOr<NonnullOwnPtr<TLSv12>>> TLSv12::async_connect(ByteString const& host, u16 port, Options options)
{
auto promise = Core::Promise<Empty>::construct();
auto promise = Core::Promise<void>::construct();
OwnPtr<Core::Socket> tcp_socket = CO_TRY(co_await Core::TCPSocket::async_connect(host, port));
CO_TRY(tcp_socket->set_blocking(false));
auto tls_socket = make<TLSv12>(move(tcp_socket), move(options));
tls_socket->set_sni(host);
tls_socket->on_connected = [=] {
promise->resolve({});
};
tls_socket->on_connected = [promise] { promise->resolve(); };
tls_socket->on_tls_error = [&tls_socket = *tls_socket, promise](auto alert) {
tls_socket.try_disambiguate_error();
promise->reject(AK::Error::from_string_view(enum_to_string(alert)));
@@ -94,7 +93,7 @@ Coroutine<ErrorOr<NonnullOwnPtr<TLSv12>>> TLSv12::async_connect(ByteString const
tls_socket.on_connected = nullptr;
};
CO_TRY(co_await PromiseAwaiter<Empty> { promise });
CO_TRY(co_await PromiseAwaiter<void> { promise });
tls_socket->m_context.should_expect_successful_read = true;
co_return tls_socket;
@@ -102,15 +101,13 @@ Coroutine<ErrorOr<NonnullOwnPtr<TLSv12>>> TLSv12::async_connect(ByteString const
Coroutine<ErrorOr<NonnullOwnPtr<TLSv12>>> TLSv12::async_connect(ByteString const& host, Core::Socket& underlying_stream, Options options)
{
auto promise = Core::Promise<Empty>::construct();
auto promise = Core::Promise<void>::construct();
CO_TRY(underlying_stream.set_blocking(false));
auto tls_socket = make<TLSv12>(&underlying_stream, move(options));
tls_socket->set_sni(host);
tls_socket->on_connected = [=] {
promise->resolve({});
};
tls_socket->on_tls_error = [&, promise](auto alert) {
tls_socket->try_disambiguate_error();
tls_socket->on_connected = [promise] { promise->resolve(); };
tls_socket->on_tls_error = [&tls_socket = *tls_socket, promise](auto alert) {
tls_socket.try_disambiguate_error();
promise->reject(AK::Error::from_string_view(enum_to_string(alert)));
};
@@ -119,7 +116,7 @@ Coroutine<ErrorOr<NonnullOwnPtr<TLSv12>>> TLSv12::async_connect(ByteString const
tls_socket.on_connected = nullptr;
};
CO_TRY(co_await PromiseAwaiter<Empty> { promise });
CO_TRY(co_await PromiseAwaiter<void> { promise });
tls_socket->m_context.should_expect_successful_read = true;
co_return tls_socket;

View File

@@ -36,18 +36,18 @@ public:
template<typename... PlaceholderValues>
void execute_statement(SQL::StatementID statement_id, OnResult on_result, OnComplete on_complete, OnError on_error, PlaceholderValues&&... placeholder_values)
{
auto sync_promise = Core::Promise<Empty>::construct();
auto sync_promise = Core::Promise<void>::construct();
PendingExecution pending_execution {
.on_result = move(on_result),
.on_complete = [sync_promise, on_complete = move(on_complete)] {
if (on_complete)
on_complete();
sync_promise->resolve({}); },
sync_promise->resolve(); },
.on_error = [sync_promise, on_error = move(on_error)](auto message) {
if (on_error)
on_error(message);
sync_promise->resolve({}); },
sync_promise->resolve(); },
};
Vector<SQL::Value> values { SQL::Value(forward<PlaceholderValues>(placeholder_values))... };

View File

@@ -411,10 +411,8 @@ static ErrorOr<TestResult> run_test(HeadlessWebContentView& view, StringView inp
{
// Clear the current document.
// FIXME: Implement a debug-request to do this more thoroughly.
auto promise = Core::Promise<Empty>::construct();
view.on_load_finish = [&](auto) {
promise->resolve({});
};
auto promise = Core::Promise<void>::construct();
view.on_load_finish = [&](auto const&) { promise->resolve(); };
view.on_text_test_finish = {};
view.on_request_file_picker = [&](auto const& accepted_file_types, auto allow_multiple_files) {