diff --git a/AK/Coroutine.h b/AK/Coroutine.h index a760ba7fa52..f1cf6019187 100644 --- a/AK/Coroutine.h +++ b/AK/Coroutine.h @@ -31,6 +31,12 @@ struct SuspendNever { void await_resume() const noexcept { } }; +struct SuspendAlways { + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<>) const noexcept { } + void await_resume() const noexcept { } +}; + struct SymmetricControlTransfer { SymmetricControlTransfer(std::coroutine_handle<> handle) : m_handle(handle ? handle : std::noop_coroutine()) diff --git a/AK/Generator.h b/AK/Generator.h new file mode 100644 index 00000000000..9ec1c7b746f --- /dev/null +++ b/AK/Generator.h @@ -0,0 +1,213 @@ +/* + * Copyright (c) 2024, Dan Klishch + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include +#include + +namespace AK { + +namespace Detail { +class YieldAwaiter { +public: + YieldAwaiter(std::coroutine_handle<> control_transfer, std::coroutine_handle<>& awaiter) + : m_control_transfer(control_transfer) + , m_awaiter(awaiter) + { + } + + bool await_ready() const { return false; } + + auto await_suspend(std::coroutine_handle<> handle) + { + m_awaiter = handle; + return m_control_transfer; + } + + void await_resume() { } + +private: + std::coroutine_handle<> m_control_transfer; + std::coroutine_handle<>& m_awaiter; +}; +} + +template +class [[nodiscard]] Generator { + struct GeneratorPromiseType; + + AK_MAKE_NONCOPYABLE(Generator); + +public: + using YieldType = Y; + using ReturnType = R; + using promise_type = GeneratorPromiseType; + + ~Generator() + { + destroy_stored_object(); + if (m_handle) + m_handle.destroy(); + } + + Generator(Generator&& other) + { + m_handle = AK::exchange(other.m_handle, {}); + m_read_returned_object = exchange(other.m_read_returned_object, false); + + m_currently_stored_type = other.m_currently_stored_type; + if (m_currently_stored_type == CurrentlyStoredType::Yield) { + new (m_data) YieldType(move(*reinterpret_cast(other.m_data))); + } else if (m_currently_stored_type == CurrentlyStoredType::Return) { + new (m_data) ReturnType(move(*reinterpret_cast(other.m_data))); + } + other.destroy_stored_object(); + + if (m_handle) + m_handle.promise().m_coroutine = this; + } + + Generator& operator=(Generator&& other) + { + if (this != &other) { + this->~Generator(); + new (this) Generator(move(other)); + } + return *this; + } + + bool is_done() const { return !m_handle || m_handle.done(); } + + void destroy() + { + VERIFY(m_handle && !m_handle.promise().m_awaiter); + destroy_stored_object(); + m_handle.destroy(); + m_handle = {}; + } + + Coroutine> next() + { + if (!is_done()) { + VERIFY(m_currently_stored_type != CurrentlyStoredType::Return); + co_await Detail::YieldAwaiter { m_handle, m_handle.promise().m_awaiter }; + if (m_handle) + m_handle.promise().m_awaiter = {}; + } + + if (is_done()) { + VERIFY(m_currently_stored_type == CurrentlyStoredType::Return && !m_read_returned_object); + m_read_returned_object = true; + co_return move(*reinterpret_cast(m_data)); + } else { + VERIFY(m_currently_stored_type == CurrentlyStoredType::Yield); + co_return move(*reinterpret_cast(m_data)); + } + } + +private: + template + friend struct Detail::TryAwaiter; + + struct GeneratorPromiseType { + Generator get_return_object() + { + return { std::coroutine_handle::from_promise(*this) }; + } + + Detail::SuspendAlways initial_suspend() { return {}; } + + Detail::SymmetricControlTransfer final_suspend() noexcept + { + VERIFY(m_awaiter); + return { m_awaiter }; + } + + template + requires requires { { T(forward(declval())) }; } + void return_value(U&& returned_object) + { + m_coroutine->place_returned_object(forward(returned_object)); + } + + void return_value(ReturnType&& returned_object) + { + m_coroutine->place_returned_object(move(returned_object)); + } + + Detail::SymmetricControlTransfer yield_value(YieldType&& yield_value) + { + m_coroutine->place_yield_object(move(yield_value)); + VERIFY(m_awaiter); + return { m_awaiter }; + } + + std::coroutine_handle<> m_awaiter; + Generator* m_coroutine { nullptr }; // Must be named `m_coroutine` for CO_TRY to work + }; + + Generator(std::coroutine_handle&& handle) + : m_handle(move(handle)) + { + m_handle.promise().m_coroutine = this; + } + + void destroy_stored_object() + { + switch (m_currently_stored_type) { + case CurrentlyStoredType::Empty: + break; + case CurrentlyStoredType::Yield: + reinterpret_cast(m_data)->~YieldType(); + break; + case CurrentlyStoredType::Return: + reinterpret_cast(m_data)->~ReturnType(); + break; + } + m_currently_stored_type = CurrentlyStoredType::Empty; + } + + template + YieldType* place_yield_object(Args&&... args) + { + destroy_stored_object(); + m_currently_stored_type = CurrentlyStoredType::Yield; + return new (m_data) YieldType(forward(args)...); + } + + template + ReturnType* place_returned_object(Args&&... args) + { + destroy_stored_object(); + m_currently_stored_type = CurrentlyStoredType::Return; + return new (m_data) ReturnType(forward(args)...); + } + + ReturnType* return_value() // Must be defined for CO_TRY. + { + destroy_stored_object(); + m_currently_stored_type = CurrentlyStoredType::Return; + return reinterpret_cast(m_data); + } + + std::coroutine_handle m_handle; + + enum class CurrentlyStoredType { + Empty, + Yield, + Return, + } m_currently_stored_type + = CurrentlyStoredType::Empty; + bool m_read_returned_object { false }; + alignas(max(alignof(YieldType), alignof(ReturnType))) u8 m_data[max(sizeof(YieldType), sizeof(ReturnType))]; +}; + +} + +#ifdef USING_AK_GLOBALLY +using AK::Generator; +#endif diff --git a/Tests/AK/CMakeLists.txt b/Tests/AK/CMakeLists.txt index 276924d7779..73091c14170 100644 --- a/Tests/AK/CMakeLists.txt +++ b/Tests/AK/CMakeLists.txt @@ -35,6 +35,7 @@ set(AK_TEST_SOURCES TestFlyString.cpp TestFormat.cpp TestFuzzyMatch.cpp + TestGeneratorAK.cpp TestGenericLexer.cpp TestHashFunctions.cpp TestHashMap.cpp diff --git a/Tests/AK/TestGeneratorAK.cpp b/Tests/AK/TestGeneratorAK.cpp new file mode 100644 index 00000000000..eba3d41f859 --- /dev/null +++ b/Tests/AK/TestGeneratorAK.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2024, Dan Klishch + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include +#include + +namespace { + +Generator generate_sync(Vector& order) +{ + ScopeGuard guard = [&] { + order.append(7); + }; + + order.append(2); + co_yield 1; + order.append(4); + co_yield 2; + order.append(6); + co_return {}; +} + +} + +ASYNC_TEST_CASE(sync_order) +{ + Vector order; + + auto gen = generate_sync(order); + EXPECT(!gen.is_done()); + + order.append(1); + + auto result1 = gen.next(); + order.append(3); + EXPECT(result1.await_ready()); + EXPECT_EQ(result1.await_resume(), 1); + + auto result2 = gen.next(); + order.append(5); + EXPECT(result2.await_ready()); + EXPECT_EQ(result2.await_resume(), 2); + + auto end = gen.next(); + order.append(8); + EXPECT(end.await_ready()); + EXPECT_EQ(end.await_resume(), Empty {}); + + EXPECT_EQ(order, (Vector { 1, 2, 3, 4, 5, 6, 7, 8 })); + co_return; +}