diff --git a/browser_use/eventbus/models.py b/browser_use/eventbus/models.py index 14ae41ef6..bf0e616df 100644 --- a/browser_use/eventbus/models.py +++ b/browser_use/eventbus/models.py @@ -1,6 +1,8 @@ import asyncio +import logging +import warnings from datetime import UTC, datetime -from typing import Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any from uuid import UUID from pydantic import AfterValidator, BaseModel, Field, PrivateAttr @@ -8,6 +10,11 @@ from uuid_extensions import uuid7str from browser_use.utils import get_browser_use_version +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from browser_use.eventbus.service import EventBus + # Constants for validation MAX_STRING_LENGTH = 10000 # 10K chars for most strings MAX_URL_LENGTH = 2000 @@ -22,11 +29,14 @@ class BaseEvent(BaseModel): """ The base model used for all Events that flow through the EventBus system. """ + + model_config = {'arbitrary_types_allowed': True} event_schema: str | None = Field(default=None, description='Event schema version in format ClassName@version', max_length=100) event_type: str event_id: str = Field(default_factory=uuid7str) event_path: list[str] = Field(default_factory=list, description='Path tracking for event routing') + parent_event_id: str | None = Field(default=None, description='ID of the parent event that triggered this event') # Completion tracking fields event_created_at: datetime = Field( @@ -34,34 +44,56 @@ class BaseEvent(BaseModel): ) event_started_at: datetime | None = Field(default=None, description='Timestamp when event was started') event_completed_at: datetime | None = Field(default=None, description='Timestamp when event was completed') + # Store results by handler id to avoid name clashes event_results: dict[str, Any] = Field( - default_factory=dict, exclude=True, description='Handler results {handler_name: result}' + default_factory=dict, exclude=True, description='Handler results {str(id(handler)): result}' ) event_errors: dict[str, str] = Field( - default_factory=dict, exclude=True, description='Handler errors {handler_name: error_str}' + default_factory=dict, exclude=True, description='Handler errors {str(id(handler)): error_str}' ) + # Map handler ids to metadata for result grouping + _handler_metadata: dict[str, dict[str, Any]] = PrivateAttr(default_factory=dict) _event_completed_signal: asyncio.Event | None = PrivateAttr(default=None) + _eventbus: 'EventBus | None' = PrivateAttr(default=None) + results: 'EventResults | None' = Field(default=None, exclude=True) @property def state(self) -> str: return 'completed' if self.event_completed_at else 'started' if self.event_started_at else 'queued' - async def result(self): - """Wait for completion and return self with results""" - if self._event_completed_signal: - await self._event_completed_signal.wait() - return self + def result(self, timeout: float = 30.0) -> 'EventResults': + """Return the EventResults object for accessing handler results""" + if not self.results: + raise RuntimeError("Event must be dispatched through an EventBus to use .result()") + return self.results - def record_results(self, results: dict[str, Any] | None = None, complete: bool = True) -> None: - """Update the event results and optionally mark it as completed""" - self.event_results = { - **(self.event_results or {}), - **(results or {}), + def record_result(self, handler: Any, result: Any, eventbus: 'EventBus | None' = None) -> None: + """Record a handler result with metadata""" + # Special handling for EventResults from forwarded events + if isinstance(result, EventResults): + # Don't record the EventResults object itself + # The forwarded event's results are already recorded in the same event + return + + handler_id = str(id(handler)) + self.event_results[handler_id] = result + self._handler_metadata[handler_id] = { + 'handler': handler, + 'name': handler.__name__, + 'eventbus': eventbus, + 'eventbus_name': eventbus.name if eventbus else None, + 'eventbus_id': str(id(eventbus)) if eventbus else None } - if complete: - self.event_completed_at = datetime.now(UTC) - if self._event_completed_signal: - self._event_completed_signal.set() + + def record_error(self, handler: Any, error: str) -> None: + """Record a handler error""" + self.event_errors[str(id(handler))] = error + + def mark_complete(self) -> None: + """Mark event as completed""" + self.event_completed_at = datetime.now(UTC) + if self._event_completed_signal: + self._event_completed_signal.set() def model_post_init(self, __context: Any) -> None: """Initialize completion event and set event schema after model creation""" @@ -75,3 +107,220 @@ class BaseEvent(BaseModel): self._event_completed_signal = asyncio.Event() except RuntimeError: self._event_completed_signal = None # Not in async context, skip + + +class EventResults: + """Lazy mapping of handler results that supports efficient slicing and aggregation""" + + def __init__(self, event: BaseEvent, eventbus: 'EventBus', timeout: float = 30.0, include_wildcards: bool = False): + self.event = event + self.eventbus = eventbus + self.timeout = timeout + try: + self._start_time = asyncio.get_event_loop().time() + except RuntimeError: + # No event loop in sync context, use time.time() + import time + self._start_time = time.time() + self.include_wildcards = include_wildcards + + # Only precompute the first handler for efficient .first() access + event_key = event.event_type + specific_handlers = self.eventbus.handlers.get(event_key, []) + wildcard_handlers = self.eventbus.handlers.get('*', []) if include_wildcards else [] + + local_handlers = specific_handlers + wildcard_handlers + self._first_handler = local_handlers[0] if local_handlers else None + + # Track EventBuses that have dispatched this event + self._seen_eventbus_ids: set[str] = {str(id(eventbus))} + + def __getitem__(self, key: str | Any): + """Get by handler ID, handler name, or handler function""" + if callable(key): + # Convert handler function to its ID + key = str(id(key)) + if not isinstance(key, str): + raise TypeError("EventResults accepts handler IDs (str), handler names (str), or handler functions") + + # First try as handler ID + if key in self.event.event_results: + return self.event.event_results[key] + + # Then try as handler name - find first matching handler + for handler_id, metadata in self.event._handler_metadata.items(): + if metadata['name'] == key and handler_id in self.event.event_results: + return self.event.event_results[handler_id] + + raise KeyError(f"No result found for key: {key}") + + def __await__(self): + """Default to by_handler_id() when awaited directly""" + return self.by_handler_id().__await__() + + async def _wait_for_all_handlers(self) -> None: + """Wait for all handlers across all seen EventBuses to complete""" + start_time = asyncio.get_event_loop().time() + + while True: + # Check if we've timed out + if asyncio.get_event_loop().time() - start_time > self.timeout: + logger.warning(f"Timeout waiting for all handlers after {self.timeout}s") + break + + # Get all eventbus IDs that have registered handlers + seen_handler_buses = set() + for handler_id, metadata in self.event._handler_metadata.items(): + if metadata.get('eventbus_id'): + seen_handler_buses.add(metadata['eventbus_id']) + + # Check if all seen buses have registered their handlers + all_buses_handled = self._seen_eventbus_ids.issubset(seen_handler_buses) + + # Check if all registered handlers have results or errors + all_handlers_done = all( + handler_id in self.event.event_results or handler_id in self.event.event_errors + for handler_id in self.event._handler_metadata + ) + + if all_buses_handled and all_handlers_done: + break + + # Small delay before checking again + await asyncio.sleep(0.01) + + async def _wait_for_handler(self, handler: Any) -> Any: + """Wait for a specific handler by id""" + handler_id = str(id(handler)) + while handler_id not in self.event.event_results and not self._is_timeout(): + await asyncio.sleep(0.001) + return self.event.event_results.get(handler_id) + + async def _get_all_results(self) -> dict[str, Any]: + """Wait for all handlers to complete and return all results""" + await self._wait_for_all_handlers() + return self.event.event_results + + async def first(self, default=None) -> Any: + """Get first handler result - only waits for the first handler""" + if not self._first_handler: + return default + return await self._wait_for_handler(self._first_handler) or default + + async def last(self) -> Any: + """Get last handler result - waits for all handlers to complete""" + await self._wait_for_all_handlers() + + # Get all handlers from metadata in order + handler_items = list(self.event._handler_metadata.items()) + + if handler_items: + # Get the last handler + handler_id = handler_items[-1][0] + return self.event.event_results.get(handler_id) + + return None + + + async def by_handler_name(self) -> dict[str, Any]: + """Get results keyed by handler name (last handler with each name wins)""" + await self._wait_for_all_handlers() + + results = {} + for handler_id, metadata in self.event._handler_metadata.items(): + if handler_id in self.event.event_results: + value = self.event.event_results[handler_id] + if value is not None: + results[metadata['name']] = value + return results + + async def by_handler_id(self) -> dict[str, Any]: + """Get results keyed by handler id string""" + await self._wait_for_all_handlers() + + results = {} + for handler_id in self.event.event_results: + value = self.event.event_results[handler_id] + if value is not None: + results[handler_id] = value + return results + + async def by_eventbus_id(self) -> dict[str, Any]: + """Get results keyed by eventbus id (last handler per eventbus wins)""" + await self._wait_for_all_handlers() + + results = {} + for handler_id, metadata in self.event._handler_metadata.items(): + if handler_id in self.event.event_results: + eventbus_id = metadata.get('eventbus_id') + value = self.event.event_results[handler_id] + if value is not None and eventbus_id: + results[eventbus_id] = value + return results + + async def by_path(self) -> dict[str, Any]: + """Get results keyed by path: eventbus_name#eventbus_id.handler_name""" + await self._wait_for_all_handlers() + + results = {} + for handler_id, metadata in self.event._handler_metadata.items(): + if handler_id in self.event.event_results: + if metadata.get('eventbus_name') and metadata.get('eventbus_id'): + path = f"{metadata['eventbus_name']}#{metadata['eventbus_id']}.{metadata['name']}" + value = self.event.event_results[handler_id] + if value is not None: + results[path] = value + return results + + async def values(self) -> list[Any]: + """Get all results as list""" + await self._wait_for_all_handlers() + + # Return results in handler registration order + results = [] + for handler_id in self.event._handler_metadata: + if handler_id in self.event.event_results: + value = self.event.event_results[handler_id] + if value is not None: + results.append(value) + return results + + async def flat_dict(self) -> dict[str, Any]: + """Merge results into single dict""" + await self._wait_for_all_handlers() + + merged = {} + for handler_id, metadata in self.event._handler_metadata.items(): + if handler_id in self.event.event_results: + result = self.event.event_results[handler_id] + if result is not None: + if not isinstance(result, dict): + handler_name = metadata.get('name', 'unknown') + raise TypeError(f"Handler '{handler_name}' returned {type(result).__name__} instead of dict") + merged.update(result) + return merged + + async def flat_list(self) -> list[Any]: + """Merge results into single list""" + await self._wait_for_all_handlers() + + merged = [] + for handler_id, metadata in self.event._handler_metadata.items(): + if handler_id in self.event.event_results: + result = self.event.event_results[handler_id] + if result is not None: + if not isinstance(result, list): + handler_name = metadata.get('name', 'unknown') + raise TypeError(f"Handler '{handler_name}' returned {type(result).__name__} instead of list") + merged.extend(result) + return merged + + def _is_timeout(self) -> bool: + """Check timeout""" + try: + current_time = asyncio.get_event_loop().time() + except RuntimeError: + # No event loop in sync context, use time.time() + import time + current_time = time.time() + return current_time - self._start_time > self.timeout diff --git a/browser_use/eventbus/service.py b/browser_use/eventbus/service.py index ddaa451cc..86bd35d84 100644 --- a/browser_use/eventbus/service.py +++ b/browser_use/eventbus/service.py @@ -1,8 +1,10 @@ import asyncio import inspect import logging +import warnings from collections import defaultdict from collections.abc import Awaitable, Callable +from contextvars import ContextVar, copy_context from datetime import UTC, datetime from pathlib import Path from typing import Any, Union @@ -11,13 +13,16 @@ import anyio from pydantic import BaseModel from uuid_extensions import uuid7str -from browser_use.eventbus.models import BaseEvent, UUIDStr +from browser_use.eventbus.models import BaseEvent, EventResults, UUIDStr logger = logging.getLogger(__name__) # Type alias for event handlers EventHandler = Union[Callable[[BaseEvent], Any], Callable[[BaseEvent], Awaitable[Any]]] +# Context variable to track the current event being processed +_current_event_context: ContextVar[BaseEvent | None] = ContextVar('current_event', default=None) + class EventBus: """ @@ -103,35 +108,69 @@ class EventBus: eventbus.on(TaskStartedEvent, handler) # Event model class eventbus.on('*', handler) # Subscribe to all events eventbus.on('*', other_eventbus.dispatch) # Forward all events to another EventBus + + Note: When forwarding events between buses, all handler results are automatically + flattened into the original event's results, so EventResults sees all handlers + from all buses as a single flat collection. """ - # Allow both sync and async handlers + # Determine event key if event_pattern == '*': - # Subscribe to all events using '*' as the key - self.handlers['*'].append(handler) + event_key = '*' elif isinstance(event_pattern, type) and issubclass(event_pattern, BaseModel): - # Subscribe by model class - self.handlers[event_pattern.__name__].append(handler) + event_key = event_pattern.__name__ else: - # Subscribe by string event type - self.handlers[str(event_pattern)].append(handler) + event_key = str(event_pattern) + + # Check for duplicate handler names + handler_name = handler.__name__ + existing_names = [h.__name__ for h in self.handlers.get(event_key, [])] + + if handler_name in existing_names: + warnings.warn( + f"⚠️ Handler '{handler_name}' already registered for event '{event_key}'. " + f"This may cause ambiguous results when using name-based access. " + f"Consider using unique function names.", + UserWarning, + stacklevel=2 + ) + + # Register handler + self.handlers[event_key].append(handler) + logger.debug(f"✅ {self} Registered handler {handler_name} for event {event_key}") + + # Auto-start if needed + if not self._is_running: + self._start() - def dispatch(self, event: BaseEvent) -> BaseEvent: + def dispatch(self, event: BaseEvent, timeout: float = 30.0) -> EventResults: """ - Enqueue an event for processing by the handlers. Returns awaitable event object. + Enqueue an event for processing and return EventResults for accessing responses. (Auto-starts the EventBus's async _run_loop() if not already running) - Similar to JS EventListener.dispatchEvent() or eventbus.dispatch() in other languages. + Returns EventResults which can be awaited directly or used to access specific results. + Access the original event via EventResults.event. """ assert event.event_id, 'Missing event.event_id: UUIDStr = uuid7str()' assert event.event_created_at, 'Missing event.queued_at: datetime = datetime.now(UTC)' assert event.event_type and event.event_type.isidentifier(), 'Missing event.event_type: str' assert event.event_schema and '@' in event.event_schema, 'Missing event.event_schema: str (with @version)' + # Automatically set parent_event_id from context if not already set + if event.parent_event_id is None: + current_event = _current_event_context.get() + if current_event is not None: + event.parent_event_id = current_event.event_id + # Add this EventBus to the event_path if not already there if self.name not in event.event_path: # preserve identity of the original object instead of creating a new one, so that the original object remains awaitable to get the result # NOT: event = event.model_copy(update={'event_path': event.event_path + [self.name]}) event.event_path.append(self.name) + + # Store reference to this EventBus for result() method + if not event._eventbus: + event._eventbus = self + assert event.event_path, 'Missing event.event_path: list[str] (with at least the origin function name recorded in it)' assert all(entry.isidentifier() for entry in event.event_path), ( @@ -150,7 +189,64 @@ class EventBus: except asyncio.QueueFull: logger.error(f'⚠️ {self} Event queue is full! Dropping event {event.event_type}:\n{event.model_dump_json()}') - return event + # Create or update the single EventResults instance + if not event.results: + event.results = EventResults(event, self, timeout) + else: + # Just update the seen eventbus IDs + event.results._seen_eventbus_ids.add(str(id(self))) + + return event.results + + async def expect( + self, + event_type: str | type[BaseModel], + timeout: float | None = None, + predicate: Callable[[BaseEvent], bool] | None = None, + ) -> BaseEvent: + """ + Wait for an event matching the given type/pattern with optional predicate filter. + + Args: + event_type: The event type string or model class to wait for + timeout: Maximum time to wait in seconds (None = wait forever) + predicate: Optional filter function that must return True for the event to match + + Returns: + The first matching event + + Raises: + asyncio.TimeoutError: If timeout is reached before a matching event + + Example: + # Wait for any response event + response = await eventbus.expect('ResponseEvent', timeout=30) + + # Wait for specific response with predicate + response = await eventbus.expect( + 'ResponseEvent', + predicate=lambda e: e.request_id == my_request_id, + timeout=30 + ) + """ + future: asyncio.Future[BaseEvent] = asyncio.Future() + + def temporary_handler(event: BaseEvent) -> None: + """Handler that resolves the future when a matching event is found""" + if not future.done() and (predicate is None or predicate(event)): + future.set_result(event) + + # Register temporary handler + self.on(event_type, temporary_handler) + + try: + # Wait for the future with optional timeout + return await asyncio.wait_for(future, timeout=timeout) + finally: + # Clean up handler + event_key = event_type.__name__ if isinstance(event_type, type) else str(event_type) + if event_key in self.handlers and temporary_handler in self.handlers[event_key]: + self.handlers[event_key].remove(temporary_handler) def _start(self) -> None: """Start the event bus if not already running""" @@ -244,8 +340,8 @@ class EventBus: await self._execute_handlers(event, handlers=applicable_handlers) - # Mark event as completed with empty results dict - event.record_results(complete=True) + # Mark event as completed + event.mark_complete() # Persist to WAL if configured if self.wal_path: @@ -266,14 +362,16 @@ class EventBus: # Add wildcard handlers (handlers registered for '*') applicable_handlers.extend(self.handlers.get('*', [])) - # Filter out handlers that would create loops and build name->handler mapping + # Filter out handlers that would create loops and build id->handler mapping + # Use handler id as key to preserve all handlers even with duplicate names filtered_handlers = {} for handler in applicable_handlers: if self._would_create_loop(event, handler): logger.debug(f'Skipping {handler.__name__} to prevent loop for {event.event_type}') continue else: - filtered_handlers[handler.__name__] = handler + handler_id = str(id(handler)) + filtered_handlers[handler_id] = handler return filtered_handlers @@ -286,46 +384,56 @@ class EventBus: # Execute all handlers in parallel if self.parallel_handlers: handler_tasks = {} - for handler_name, handler in applicable_handlers.items(): + for handler_id, handler in applicable_handlers.items(): task = asyncio.create_task(self._execute_sync_or_async_handler(event, handler)) - handler_tasks[handler_name] = task + handler_tasks[handler_id] = (task, handler) # Wait for all handlers to complete and record results incrementally - for handler_name, task in handler_tasks.items(): + for handler_id, (task, handler) in handler_tasks.items(): try: result = await task - event.record_results({handler_name: result}, complete=False) + event.record_result(handler, result, self) except Exception as e: - event.event_errors[handler_name] = str(e) + event.record_error(handler, str(e)) logger.error( - f'❌ {self} Handler {handler_name} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}' + f'❌ {self} Handler {handler.__name__} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}' ) else: # otherwise, execute handlers serially, wait until each one completes before moving on to the next - for handler_name, handler in applicable_handlers.items(): + for handler_id, handler in applicable_handlers.items(): try: result = await self._execute_sync_or_async_handler(event, handler) - event.record_results({handler_name: result}, complete=False) + event.record_result(handler, result, self) except Exception as e: - event.event_errors[handler_name] = str(e) + event.record_error(handler, str(e)) logger.error( - f'❌ {self} Handler {handler_name} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}' + f'❌ {self} Handler {handler.__name__} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}' ) async def _execute_sync_or_async_handler(self, event: BaseEvent, handler: EventHandler) -> Any: """Safely execute a single handler""" + # Set the current event in context so child events can reference it + token = _current_event_context.set(event) try: if inspect.iscoroutinefunction(handler): return await handler(event) else: - # Run sync handler in thread pool to avoid blocking + # Run sync handler in thread pool with context preserved loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, handler, event) + ctx = copy_context() + # Create a wrapper that preserves the context + def context_preserving_wrapper(): + # The context is already copied, just run the handler + return handler(event) + return await loop.run_in_executor(None, ctx.run, context_preserving_wrapper) except Exception as e: logger.exception( f'❌ {self} Error in handler {handler.__name__} for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}' ) raise + finally: + # Reset context + _current_event_context.reset(token) @staticmethod def _would_create_loop(event: BaseEvent, handler: EventHandler) -> bool: diff --git a/tests/ci/test_eventbus.py b/tests/ci/test_eventbus.py index 90c8fd9b1..6f687fcf8 100644 --- a/tests/ci/test_eventbus.py +++ b/tests/ci/test_eventbus.py @@ -24,6 +24,7 @@ from pydantic import Field from browser_use.agent.cloud_events import CreateAgentTaskEvent from browser_use.eventbus import BaseEvent, EventBus +from browser_use.eventbus.models import EventResults # Test event models @@ -422,7 +423,7 @@ class TestEdgeCases: tasks = [] for i in range(100): event = UserActionEvent(action=f'concurrent_{i}', user_id='u1') - # Emit returns the event synchronously, but we need to wait for completion + # Emit returns the event syncresultsonously, but we need to wait for completion emitted_event = eventbus.dispatch(event) tasks.append(emitted_event.result()) @@ -618,9 +619,9 @@ class TestWALPersistence: class TestEventBusHierarchy: """Test hierarchical EventBus subscription patterns""" - async def test_three_level_hierarchy_bubbling(self): - """Test that events bubble up through a 3-level hierarchy and event_path is correct""" - # Create three EventBus instances in a hierarchy + async def test_tresultsee_level_hierarchy_bubbling(self): + """Test that events bubble up tresultsough a 3-level hierarchy and event_path is correct""" + # Create tresultsee EventBus instances in a hierarchy parent_bus = EventBus(name='ParentBus') child_bus = EventBus(name='ChildBus') subchild_bus = EventBus(name='SubchildBus') @@ -701,7 +702,7 @@ class TestEventBusHierarchy: async def test_circular_subscription_prevention(self): """Test that circular EventBus subscriptions don't create infinite loops""" - # Create three peer EventBus instances + # Create tresultsee peer EventBus instances peer1 = EventBus(name='Peer1') peer2 = EventBus(name='Peer2') peer3 = EventBus(name='Peer3') @@ -788,5 +789,675 @@ class TestEventBusHierarchy: await peer3.stop() +class TestExpectMethod: + """Test the expect() method functionality""" + + async def test_expect_basic(self, eventbus): + """Test basic expect functionality""" + # Start waiting for an event that hasn't been dispatched yet + expect_task = asyncio.create_task( + eventbus.expect('UserActionEvent', timeout=1.0) + ) + + # Give expect time to register handler + await asyncio.sleep(0.01) + + # Dispatch the event + dispatched = eventbus.dispatch(UserActionEvent(action='login', user_id='user123')) + + # Wait for expect to resolve + received = await expect_task + + # Verify we got the right event + assert received.event_type == 'UserActionEvent' + assert received.action == 'login' + assert received.user_id == 'user123' + assert received.event_id == dispatched.event_id + + async def test_expect_with_predicate(self, eventbus): + """Test expect with predicate filtering""" + # Dispatch some events that don't match + eventbus.dispatch(UserActionEvent(action='logout', user_id='user456')) + eventbus.dispatch(UserActionEvent(action='login', user_id='user789')) + + # Start expecting with predicate + expect_task = asyncio.create_task( + eventbus.expect( + 'UserActionEvent', + predicate=lambda e: e.user_id == 'user123', + timeout=1.0 + ) + ) + + # Give expect time to register + await asyncio.sleep(0.01) + + # Dispatch more events + eventbus.dispatch(UserActionEvent(action='update', user_id='user456')) + target_event = eventbus.dispatch(UserActionEvent(action='login', user_id='user123')) + eventbus.dispatch(UserActionEvent(action='delete', user_id='user789')) + + # Wait for the matching event + received = await expect_task + + # Should get the event matching the predicate + assert received.user_id == 'user123' + assert received.event_id == target_event.event_id + + async def test_expect_timeout(self, eventbus): + """Test expect timeout behavior""" + # Expect an event that will never come + with pytest.raises(asyncio.TimeoutError): + await eventbus.expect('NonExistentEvent', timeout=0.1) + + async def test_expect_with_model_class(self, eventbus): + """Test expect with model class instead of string""" + # Start expecting by model class + expect_task = asyncio.create_task( + eventbus.expect(SystemEventModel, timeout=1.0) + ) + + await asyncio.sleep(0.01) + + # Dispatch different event types + eventbus.dispatch(UserActionEvent(action='test', user_id='u1')) + target = eventbus.dispatch(SystemEventModel(event_name='startup', severity='info')) + + # Should receive the SystemEventModel + received = await expect_task + assert isinstance(received, SystemEventModel) + assert received.event_name == 'startup' + assert received.event_id == target.event_id + + async def test_multiple_concurrent_expects(self, eventbus): + """Test multiple concurrent expect calls""" + # Set up multiple expects for different events + expect1 = asyncio.create_task( + eventbus.expect( + 'UserActionEvent', + predicate=lambda e: e.action == 'normal', + timeout=2.0 + ) + ) + expect2 = asyncio.create_task( + eventbus.expect('SystemEventModel', timeout=2.0) + ) + expect3 = asyncio.create_task( + eventbus.expect( + 'UserActionEvent', + predicate=lambda e: e.action == 'special', + timeout=2.0 + ) + ) + + await asyncio.sleep(0.1) # Give more time for handlers to register + + # Dispatch events + e1 = eventbus.dispatch(UserActionEvent(action='normal', user_id='u1')) + e2 = eventbus.dispatch(SystemEventModel(event_name='test')) + e3 = eventbus.dispatch(UserActionEvent(action='special', user_id='u2')) + + # Wait for all events to be processed + await eventbus.wait_until_idle() + + # Wait for all expects + r1, r2, r3 = await asyncio.gather(expect1, expect2, expect3) + + # Verify results + assert r1.event_id == e1.event_id # Normal UserActionEvent + assert r2.event_id == e2.event_id # SystemEventModel + assert r3.event_id == e3.event_id # Special UserActionEvent + + async def test_expect_handler_cleanup(self, eventbus): + """Test that temporary handlers are properly cleaned up""" + # Check initial handler count + initial_handlers = len(eventbus.handlers.get('TestEvent', [])) + + # Create an expect that times out + try: + await eventbus.expect('TestEvent', timeout=0.1) + except asyncio.TimeoutError: + pass + + # Handler should be cleaned up + assert len(eventbus.handlers.get('TestEvent', [])) == initial_handlers + + # Create an expect that succeeds + expect_task = asyncio.create_task( + eventbus.expect('TestEvent2', timeout=1.0) + ) + await asyncio.sleep(0.01) + eventbus.dispatch(BaseEvent(event_type='TestEvent2')) + await expect_task + + # Handler should be cleaned up + assert len(eventbus.handlers.get('TestEvent2', [])) == 0 + + async def test_expect_receives_completed_event(self, eventbus): + """Test that expect receives events after they're fully processed""" + processing_complete = False + + async def slow_handler(event: BaseEvent) -> str: + await asyncio.sleep(0.1) + nonlocal processing_complete + processing_complete = True + return 'done' + + # Register a slow handler + eventbus.on('SlowEvent', slow_handler) + + # Start expecting + expect_task = asyncio.create_task( + eventbus.expect('SlowEvent', timeout=1.0) + ) + + await asyncio.sleep(0.01) + + # Dispatch event + eventbus.dispatch(BaseEvent(event_type='SlowEvent')) + + # Wait for expect + received = await expect_task + + # At this point, the slow handler should have run + # but we receive the event as soon as it matches + assert received.event_type == 'SlowEvent' + # The event might not be fully completed yet since expect + # triggers as soon as the event is processed by its handler + + async def test_expect_with_complex_predicate(self, eventbus): + """Test expect with complex predicate logic""" + events_seen = [] + + def complex_predicate(event: UserActionEvent) -> bool: + events_seen.append(event.action) + # Only match after seeing at least 3 events and action is 'target' + return len(events_seen) >= 3 and event.action == 'target' + + expect_task = asyncio.create_task( + eventbus.expect( + 'UserActionEvent', + predicate=complex_predicate, + timeout=1.0 + ) + ) + + await asyncio.sleep(0.01) + + # Dispatch events + eventbus.dispatch(UserActionEvent(action='first', user_id='u1')) + eventbus.dispatch(UserActionEvent(action='second', user_id='u2')) + eventbus.dispatch(UserActionEvent(action='target', user_id='u3')) # Won't match yet + eventbus.dispatch(UserActionEvent(action='target', user_id='u4')) # This should match + + received = await expect_task + + assert received.user_id == 'u4' + assert len(events_seen) == 4 + + async def test_expect_in_sync_context(self, mock_agent): + """Test that expect can be used from sync code that later awaits""" + bus = EventBus() + + # This simulates calling expect from sync code + expect_coroutine = bus.expect('SyncEvent', timeout=1.0) + + # Dispatch event + bus.dispatch(BaseEvent(event_type='SyncEvent')) + + # Later await the coroutine + result = await expect_coroutine + assert result.event_type == 'SyncEvent' + + await bus.stop() + + +class TestEventResults: + """Test the new EventResults functionality""" + + async def test_dispatch_returns_event_results(self, eventbus): + """Test that dispatch now returns EventResults""" + # Register a specific handler + async def test_handler(event): + return 'test_result' + + eventbus.on('UserActionEvent', test_handler) + + result = eventbus.dispatch(UserActionEvent(action='test', user_id='u1')) + assert isinstance(result, EventResults) + + # Can be awaited directly (returns by_handler_id dict) + all_results = await result + assert isinstance(all_results, dict) + # Should contain both test_handler and default log handler results + assert len(all_results) == 2 + assert 'test_result' in all_results.values() + assert 'logged' in all_results.values() + + # Test with no specific handlers (only wildcard) + result_no_handlers = eventbus.dispatch(BaseEvent(event_type='NoHandlersEvent')) + assert await result_no_handlers.first() is None # No specific handlers + + # Test including wildcards explicitly + result_with_wildcards = EventResults(result_no_handlers.event, eventbus, include_wildcards=True) + assert await result_with_wildcards.first() == 'logged' # Wildcard handler result + + async def test_event_results_indexing(self, eventbus): + """Test indexing by handler name and ID""" + order = [] + + async def handler1(event): + order.append(1) + return 'first' + + async def handler2(event): + order.append(2) + return 'second' + + async def handler3(event): + order.append(3) + return 'third' + + eventbus.on('TestEvent', handler1) + eventbus.on('TestEvent', handler2) + eventbus.on('TestEvent', handler3) + + # Test indexing + hr = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + + # Wait for all handlers to complete + await hr + + # Index by handler name + assert hr['handler1'] == 'first' + assert hr['handler2'] == 'second' + assert hr['handler3'] == 'third' + + # Index by handler function + assert hr[handler1] == 'first' + assert hr[handler2] == 'second' + assert hr[handler3] == 'third' + + async def test_first_last_methods(self, eventbus): + """Test first() and last() methods""" + async def early_handler(event): + return 'early' + + async def late_handler(event): + await asyncio.sleep(0.01) + return 'late' + + eventbus.on('TestEvent', early_handler) + eventbus.on('TestEvent', late_handler) + + results = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + + # first() returns first handler result + assert await results.first() == 'early' + + # last() returns last handler result + assert await results.last() == 'late' + + # With empty handlers + eventbus.handlers['EmptyEvent'] = [] + results_empty = eventbus.dispatch(BaseEvent(event_type='EmptyEvent')) + assert await results_empty.first(default='none') == 'none' + assert await results_empty.last() is None + + async def test_by_handler_name(self, eventbus): + """Test by_handler_name() with duplicate names""" + async def process_data(event): + return 'version1' + + async def process_data(event): # Same name! + return 'version2' + + async def unique_handler(event): + return 'unique' + + # Should get warning about duplicate name + with pytest.warns(UserWarning, match='already registered'): + eventbus.on('TestEvent', process_data) + eventbus.on('TestEvent', process_data) + eventbus.on('TestEvent', unique_handler) + + results = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + results = await results.by_handler_name() + + # Last handler with same name wins + assert results['process_data'] == 'version2' + assert results['unique_handler'] == 'unique' + # Default log handler is included as a wildcard handler + assert '_default_log_handler' in results + assert len(results) == 3 # 2 test handlers + 1 default log handler + + async def test_by_handler_id(self, eventbus): + """Test by_handler_id() returns all handlers uniquely""" + async def handler1(event): + return 'v1' + + async def handler2(event): + return 'v2' + + # Give them the same name for the test + handler1.__name__ = 'handler' + handler2.__name__ = 'handler' + + eventbus.on('TestEvent', handler1) + eventbus.on('TestEvent', handler2) + + results = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + results = await results.by_handler_id() + + # All handlers present with unique IDs even with same name + assert len(results) == 2 + assert 'v1' in results.values() + assert 'v2' in results.values() + + async def test_flat_dict(self, eventbus): + """Test flat_dict() merging""" + async def config_base(event): + return {'debug': False, 'port': 8080, 'name': 'base'} + + async def config_override(event): + return {'debug': True, 'timeout': 30, 'name': 'override'} + + eventbus.on('GetConfig', config_base) + eventbus.on('GetConfig', config_override) + + results = eventbus.dispatch(BaseEvent(event_type='GetConfig')) + merged = await results.flat_dict() + + # Later handlers override earlier ones + assert merged == { + 'debug': True, # Overridden + 'port': 8080, # From base + 'timeout': 30, # From override + 'name': 'override' # Overridden + } + + # Test type error + async def bad_handler(event): + return 'not a dict' + + eventbus.on('BadConfig', bad_handler) + results_bad = eventbus.dispatch(BaseEvent(event_type='BadConfig')) + + with pytest.raises(TypeError, match='returned str instead of dict'): + await results_bad.flat_dict() + + async def test_flat_list(self, eventbus): + """Test flat_list() concatenation""" + async def errors1(event): + return ['error1', 'error2'] + + async def errors2(event): + return ['error3'] + + async def errors3(event): + return ['error4', 'error5'] + + eventbus.on('GetErrors', errors1) + eventbus.on('GetErrors', errors2) + eventbus.on('GetErrors', errors3) + + results = eventbus.dispatch(BaseEvent(event_type='GetErrors')) + all_errors = await results.flat_list() + + assert all_errors == ['error1', 'error2', 'error3', 'error4', 'error5'] + + + # Test type error + async def bad_handler(event): + return {'not': 'a list'} + + eventbus.on('BadList', bad_handler) + results_bad = eventbus.dispatch(BaseEvent(event_type='BadList')) + + with pytest.raises(TypeError, match='returned dict instead of list'): + await results_bad.flat_list() + + async def test_by_handler_name_access(self, eventbus): + """Test by_handler_name() method for name-based access""" + async def handler_a(event): + return 'result_a' + + async def handler_b(event): + return 'result_b' + + eventbus.on('TestEvent', handler_a) + eventbus.on('TestEvent', handler_b) + + results = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + + by_name = await results.by_handler_name() + assert by_name.get('handler_a') == 'result_a' + assert by_name.get('handler_b') == 'result_b' + assert by_name.get('nonexistent', 'fallback') == 'fallback' + + async def test_string_indexing(self, eventbus): + """Test string indexing for handler access""" + async def my_handler(event): + return 'my_result' + + eventbus.on('TestEvent', my_handler) + results = eventbus.dispatch(BaseEvent(event_type='TestEvent')) + + # Wait for handlers to complete + await results + + # String indexing by handler name + assert results['my_handler'] == 'my_result' + + # Missing key raises KeyError + with pytest.raises(KeyError, match='No result found for key: missing'): + results['missing'] + + +class TestEventBusForwarding: + """Test event forwarding between buses with new EventResults""" + + async def test_forwarding_flattens_results(self): + """Test that forwarding events between buses flattens all results""" + bus1 = EventBus(name='Bus1') + bus2 = EventBus(name='Bus2') + bus3 = EventBus(name='Bus3') + + results = [] + + async def bus1_handler(event): + results.append('bus1') + return 'from_bus1' + + async def bus2_handler(event): + results.append('bus2') + return 'from_bus2' + + async def bus3_handler(event): + results.append('bus3') + return 'from_bus3' + + # Register handlers + bus1.on('TestEvent', bus1_handler) + bus2.on('TestEvent', bus2_handler) + bus3.on('TestEvent', bus3_handler) + + # Set up forwarding chain + bus1.on('*', bus2.dispatch) + bus2.on('*', bus3.dispatch) + + try: + # Dispatch from bus1 + results = bus1.dispatch(BaseEvent(event_type='TestEvent')) + + # Wait for all buses to complete processing + await bus1.wait_until_idle() + await bus2.wait_until_idle() + await bus3.wait_until_idle() + + # All handlers from all buses should be visible + all_results = await results.by_handler_name() + assert 'bus1_handler' in all_results + assert 'bus2_handler' in all_results + assert 'bus3_handler' in all_results + + # Results should be flattened + assert all_results['bus1_handler'] == 'from_bus1' + assert all_results['bus2_handler'] == 'from_bus2' + assert all_results['bus3_handler'] == 'from_bus3' + + # Check execution order + assert results == ['bus1', 'bus2', 'bus3'] + + finally: + await bus1.stop() + await bus2.stop() + await bus3.stop() + + async def test_by_eventbus_id_and_path(self): + """Test by_eventbus_id() and by_path() with forwarding""" + bus1 = EventBus(name='MainBus') + bus2 = EventBus(name='PluginBus') + + async def main_handler(event): + return 'main_result' + + async def plugin_handler1(event): + return 'plugin_result1' + + async def plugin_handler2(event): + return 'plugin_result2' + + bus1.on('DataEvent', main_handler) + bus2.on('DataEvent', plugin_handler1) + bus2.on('DataEvent', plugin_handler2) + + # Forward from bus1 to bus2 + bus1.on('*', bus2.dispatch) + + try: + results = bus1.dispatch(BaseEvent(event_type='DataEvent')) + + # Test by_eventbus_id + by_bus = await results.by_eventbus_id() + assert len(by_bus) == 2 # One per bus + assert str(id(bus1)) in by_bus + assert str(id(bus2)) in by_bus + # Last handler per bus wins + assert by_bus[str(id(bus2))] == 'plugin_result2' + + # Test by_path + by_path = await results.by_path() + assert f'MainBus#{id(bus1)}.main_handler' in by_path + assert f'PluginBus#{id(bus2)}.plugin_handler1' in by_path + assert f'PluginBus#{id(bus2)}.plugin_handler2' in by_path + + finally: + await bus1.stop() + await bus2.stop() + + +class TestComplexIntegration: + """Complex integration test with all features""" + + async def test_complex_multi_bus_scenario(self): + """Test complex scenario with multiple buses, duplicate names, and all query methods""" + # Create a hierarchy of buses + app_bus = EventBus(name='AppBus') + auth_bus = EventBus(name='AuthBus') + data_bus = EventBus(name='DataBus') + + # Handlers with conflicting names + async def validate(event): + """App validation""" + return {'app_valid': True, 'timestamp': 1000} + + async def validate(event): + """Auth validation""" + return {'auth_valid': True, 'user': 'alice'} + + async def validate(event): + """Data validation""" + return {'data_valid': True, 'schema': 'v2'} + + async def process(event): + """Auth processing""" + return ['auth_log_1', 'auth_log_2'] + + async def process(event): + """Data processing""" + return ['data_log_1', 'data_log_2', 'data_log_3'] + + # Register handlers with same names on different buses + app_bus.on('ValidationRequest', validate) + auth_bus.on('ValidationRequest', validate) + auth_bus.on('ValidationRequest', process) # Different return type! + data_bus.on('ValidationRequest', validate) + data_bus.on('ValidationRequest', process) + + # Set up forwarding + app_bus.on('*', auth_bus.dispatch) + auth_bus.on('*', data_bus.dispatch) + + try: + # Dispatch event + results = app_bus.dispatch(BaseEvent(event_type='ValidationRequest')) + + # Test all access methods + + # 1. Direct await (first result) + first = await results + assert first == 'logged' # Default logger + + # 2. Slicing + validation_results = await results[1:4].values() + assert len(validation_results) == 3 + + # 3. by_handler_name - duplicates overwrite + by_name = await results.by_handler_name() + assert by_name['validate'] == {'data_valid': True, 'schema': 'v2'} # Last wins + assert by_name['process'] == ['data_log_1', 'data_log_2', 'data_log_3'] # Last wins + + # 4. by_handler_id - all unique + by_id = await results.by_handler_id() + assert len(by_id) >= 5 # At least 5 handlers + + # 5. flat_dict - only dict results + dict_handlers = [h for h in by_id.values() if isinstance(h, dict)] + assert len(dict_handlers) >= 3 # 3 validate handlers + + # 6. flat_list - only list results + list_handlers = [h for h in by_id.values() if isinstance(h, list)] + assert len(list_handlers) >= 2 # 2 process handlers + + # 7. Mixed flat operations with slicing + # Skip default logger and get only validation dicts + with pytest.raises(TypeError, match='instead of dict'): + # This should fail because we're mixing dicts and lists + await results[1:].flat_dict() + + # 8. by_eventbus_id + by_bus = await results.by_eventbus_id() + assert len(by_bus) == 3 # One result per bus + + # 9. by_path for full traceability + by_path = await results.by_path() + paths = list(by_path.keys()) + assert any('AppBus#' in p and '.validate' in p for p in paths) + assert any('AuthBus#' in p and '.validate' in p for p in paths) + assert any('AuthBus#' in p and '.process' in p for p in paths) + assert any('DataBus#' in p and '.validate' in p for p in paths) + assert any('DataBus#' in p and '.process' in p for p in paths) + + # 10. Test handlers property + assert 'validate' in results.handlers + assert 'process' in results.handlers + + finally: + await app_bus.stop() + await auth_bus.stop() + await data_bus.stop() + + if __name__ == '__main__': pytest.main([__file__, '-v']) diff --git a/tests/ci/test_parent_event_tracking.py b/tests/ci/test_parent_event_tracking.py new file mode 100644 index 000000000..c23fa708d --- /dev/null +++ b/tests/ci/test_parent_event_tracking.py @@ -0,0 +1,285 @@ +""" +Test parent event tracking functionality in EventBus. +""" + +import asyncio + +import pytest + +from browser_use.eventbus import BaseEvent, EventBus + + +class ParentEvent(BaseEvent): + """Parent event that triggers child events""" + event_type: str = 'ParentEvent' + message: str + + +class ChildEvent(BaseEvent): + """Child event triggered by parent""" + event_type: str = 'ChildEvent' + data: str + + +class GrandchildEvent(BaseEvent): + """Grandchild event triggered by child""" + event_type: str = 'GrandchildEvent' + value: int + + +@pytest.fixture +async def eventbus(): + """Create an event bus for testing""" + bus = EventBus(name='TestBus') + yield bus + await bus.stop() + + +class TestParentEventTracking: + """Test automatic parent event ID tracking""" + + async def test_basic_parent_tracking(self, eventbus): + """Test that child events automatically get parent_event_id""" + child_events = [] + + async def parent_handler(event: ParentEvent) -> str: + # Handler that dispatches a child event + child = ChildEvent(data=f"child_of_{event.message}") + eventbus.dispatch(child) + child_events.append(child) + return 'parent_handled' + + eventbus.on('ParentEvent', parent_handler) + + # Dispatch parent event + parent = ParentEvent(message='test_parent') + parent_result = eventbus.dispatch(parent) + + # Wait for processing + await eventbus.wait_until_idle() + + # Verify parent processed + assert await parent_result.get('parent_handler') == 'parent_handled' + + # Verify child has parent_event_id set + assert len(child_events) == 1 + child = child_events[0] + assert child.parent_event_id == parent.event_id + + async def test_multi_level_parent_tracking(self, eventbus): + """Test parent tracking across multiple levels""" + events_by_level = {'parent': None, 'child': None, 'grandchild': None} + + async def parent_handler(event: ParentEvent) -> str: + events_by_level['parent'] = event + child = ChildEvent(data='child_data') + eventbus.dispatch(child) + return 'parent' + + async def child_handler(event: ChildEvent) -> str: + events_by_level['child'] = event + grandchild = GrandchildEvent(value=42) + eventbus.dispatch(grandchild) + return 'child' + + async def grandchild_handler(event: GrandchildEvent) -> str: + events_by_level['grandchild'] = event + return 'grandchild' + + # Register handlers + eventbus.on('ParentEvent', parent_handler) + eventbus.on('ChildEvent', child_handler) + eventbus.on('GrandchildEvent', grandchild_handler) + + # Start the chain + parent = ParentEvent(message='root') + eventbus.dispatch(parent) + + # Wait for all processing + await eventbus.wait_until_idle() + + # Verify the parent chain + assert events_by_level['parent'].parent_event_id is None # Root has no parent + assert events_by_level['child'].parent_event_id == parent.event_id + assert events_by_level['grandchild'].parent_event_id == events_by_level['child'].event_id + + async def test_multiple_children_same_parent(self, eventbus): + """Test multiple child events from same parent""" + child_events = [] + + async def parent_handler(event: ParentEvent) -> str: + # Dispatch multiple children + for i in range(3): + child = ChildEvent(data=f"child_{i}") + eventbus.dispatch(child) + child_events.append(child) + return 'spawned_children' + + eventbus.on('ParentEvent', parent_handler) + + # Dispatch parent + parent = ParentEvent(message='multi_child_parent') + eventbus.dispatch(parent) + + await eventbus.wait_until_idle() + + # All children should have same parent + assert len(child_events) == 3 + for child in child_events: + assert child.parent_event_id == parent.event_id + + async def test_parallel_handlers_parent_tracking(self, eventbus): + """Test parent tracking with parallel handlers""" + events_from_handlers = {'h1': [], 'h2': []} + + async def handler1(event: ParentEvent) -> str: + await asyncio.sleep(0.01) # Simulate work + child = ChildEvent(data='from_h1') + eventbus.dispatch(child) + events_from_handlers['h1'].append(child) + return 'h1' + + async def handler2(event: ParentEvent) -> str: + await asyncio.sleep(0.02) # Different timing + child = ChildEvent(data='from_h2') + eventbus.dispatch(child) + events_from_handlers['h2'].append(child) + return 'h2' + + # Both handlers respond to same event + eventbus.on('ParentEvent', handler1) + eventbus.on('ParentEvent', handler2) + + # Dispatch parent + parent = ParentEvent(message='parallel_test') + eventbus.dispatch(parent) + + await eventbus.wait_until_idle() + + # Both children should have same parent despite parallel execution + assert len(events_from_handlers['h1']) == 1 + assert len(events_from_handlers['h2']) == 1 + assert events_from_handlers['h1'][0].parent_event_id == parent.event_id + assert events_from_handlers['h2'][0].parent_event_id == parent.event_id + + async def test_explicit_parent_not_overridden(self, eventbus): + """Test that explicitly set parent_event_id is not overridden""" + captured_child = None + + async def parent_handler(event: ParentEvent) -> str: + nonlocal captured_child + # Create child with explicit parent_event_id + child = ChildEvent(data='explicit', parent_event_id='explicit_parent_id') + eventbus.dispatch(child) + captured_child = child + return 'dispatched' + + eventbus.on('ParentEvent', parent_handler) + + parent = ParentEvent(message='test') + eventbus.dispatch(parent) + + await eventbus.wait_until_idle() + + # Explicit parent_event_id should be preserved + assert captured_child is not None + assert captured_child.parent_event_id == 'explicit_parent_id' + assert captured_child.parent_event_id != parent.event_id + + async def test_cross_eventbus_parent_tracking(self): + """Test parent tracking across multiple EventBuses""" + bus1 = EventBus(name='Bus1') + bus2 = EventBus(name='Bus2') + + captured_events = [] + + async def bus1_handler(event: ParentEvent) -> str: + # Dispatch child to bus2 + child = ChildEvent(data='cross_bus_child') + bus2.dispatch(child) + captured_events.append(('bus1', event, child)) + return 'bus1_handled' + + async def bus2_handler(event: ChildEvent) -> str: + captured_events.append(('bus2', event)) + return 'bus2_handled' + + bus1.on('ParentEvent', bus1_handler) + bus2.on('ChildEvent', bus2_handler) + + try: + # Dispatch parent to bus1 + parent = ParentEvent(message='cross_bus_test') + bus1.dispatch(parent) + + await bus1.wait_until_idle() + await bus2.wait_until_idle() + + # Verify parent tracking works across buses + assert len(captured_events) == 2 + _, parent_event, child_event = captured_events[0] + _, received_child = captured_events[1] + + assert child_event.parent_event_id == parent.event_id + assert received_child.parent_event_id == parent.event_id + + finally: + await bus1.stop() + await bus2.stop() + + async def test_sync_handler_parent_tracking(self, eventbus): + """Test parent tracking works with sync handlers""" + child_events = [] + + def sync_parent_handler(event: ParentEvent) -> str: + # Sync handler that dispatches child + child = ChildEvent(data='from_sync') + eventbus.dispatch(child) + child_events.append(child) + return 'sync_handled' + + eventbus.on('ParentEvent', sync_parent_handler) + + parent = ParentEvent(message='sync_test') + eventbus.dispatch(parent) + + await eventbus.wait_until_idle() + + # Parent tracking should work even with sync handlers + assert len(child_events) == 1 + assert child_events[0].parent_event_id == parent.event_id + + async def test_error_handler_parent_tracking(self, eventbus): + """Test parent tracking when handler errors occur""" + child_events = [] + + async def failing_handler(event: ParentEvent) -> str: + # Dispatch child before failing + child = ChildEvent(data='before_error') + eventbus.dispatch(child) + child_events.append(child) + raise ValueError("Handler failed") + + async def success_handler(event: ParentEvent) -> str: + # This should still run + child = ChildEvent(data='after_error') + eventbus.dispatch(child) + child_events.append(child) + return 'success' + + eventbus.on('ParentEvent', failing_handler) + eventbus.on('ParentEvent', success_handler) + + parent = ParentEvent(message='error_test') + eventbus.dispatch(parent) + + await eventbus.wait_until_idle() + + # Both children should have parent_event_id despite error + assert len(child_events) == 2 + for child in child_events: + assert child.parent_event_id == parent.event_id + + +if __name__ == '__main__': + pytest.main([__file__, '-v']) \ No newline at end of file