mirror of
https://github.com/browser-use/browser-use
synced 2026-05-06 17:52:15 +02:00
wip eventbus with EventResults aggregate for all results
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, PrivateAttr
|
||||
@@ -8,6 +10,11 @@ from uuid_extensions import uuid7str
|
||||
|
||||
from browser_use.utils import get_browser_use_version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from browser_use.eventbus.service import EventBus
|
||||
|
||||
# Constants for validation
|
||||
MAX_STRING_LENGTH = 10000 # 10K chars for most strings
|
||||
MAX_URL_LENGTH = 2000
|
||||
@@ -22,11 +29,14 @@ class BaseEvent(BaseModel):
|
||||
"""
|
||||
The base model used for all Events that flow through the EventBus system.
|
||||
"""
|
||||
|
||||
model_config = {'arbitrary_types_allowed': True}
|
||||
|
||||
event_schema: str | None = Field(default=None, description='Event schema version in format ClassName@version', max_length=100)
|
||||
event_type: str
|
||||
event_id: str = Field(default_factory=uuid7str)
|
||||
event_path: list[str] = Field(default_factory=list, description='Path tracking for event routing')
|
||||
parent_event_id: str | None = Field(default=None, description='ID of the parent event that triggered this event')
|
||||
|
||||
# Completion tracking fields
|
||||
event_created_at: datetime = Field(
|
||||
@@ -34,34 +44,56 @@ class BaseEvent(BaseModel):
|
||||
)
|
||||
event_started_at: datetime | None = Field(default=None, description='Timestamp when event was started')
|
||||
event_completed_at: datetime | None = Field(default=None, description='Timestamp when event was completed')
|
||||
# Store results by handler id to avoid name clashes
|
||||
event_results: dict[str, Any] = Field(
|
||||
default_factory=dict, exclude=True, description='Handler results {handler_name: result}'
|
||||
default_factory=dict, exclude=True, description='Handler results {str(id(handler)): result}'
|
||||
)
|
||||
event_errors: dict[str, str] = Field(
|
||||
default_factory=dict, exclude=True, description='Handler errors {handler_name: error_str}'
|
||||
default_factory=dict, exclude=True, description='Handler errors {str(id(handler)): error_str}'
|
||||
)
|
||||
# Map handler ids to metadata for result grouping
|
||||
_handler_metadata: dict[str, dict[str, Any]] = PrivateAttr(default_factory=dict)
|
||||
_event_completed_signal: asyncio.Event | None = PrivateAttr(default=None)
|
||||
_eventbus: 'EventBus | None' = PrivateAttr(default=None)
|
||||
results: 'EventResults | None' = Field(default=None, exclude=True)
|
||||
|
||||
@property
|
||||
def state(self) -> str:
|
||||
return 'completed' if self.event_completed_at else 'started' if self.event_started_at else 'queued'
|
||||
|
||||
async def result(self):
|
||||
"""Wait for completion and return self with results"""
|
||||
if self._event_completed_signal:
|
||||
await self._event_completed_signal.wait()
|
||||
return self
|
||||
def result(self, timeout: float = 30.0) -> 'EventResults':
|
||||
"""Return the EventResults object for accessing handler results"""
|
||||
if not self.results:
|
||||
raise RuntimeError("Event must be dispatched through an EventBus to use .result()")
|
||||
return self.results
|
||||
|
||||
def record_results(self, results: dict[str, Any] | None = None, complete: bool = True) -> None:
|
||||
"""Update the event results and optionally mark it as completed"""
|
||||
self.event_results = {
|
||||
**(self.event_results or {}),
|
||||
**(results or {}),
|
||||
def record_result(self, handler: Any, result: Any, eventbus: 'EventBus | None' = None) -> None:
|
||||
"""Record a handler result with metadata"""
|
||||
# Special handling for EventResults from forwarded events
|
||||
if isinstance(result, EventResults):
|
||||
# Don't record the EventResults object itself
|
||||
# The forwarded event's results are already recorded in the same event
|
||||
return
|
||||
|
||||
handler_id = str(id(handler))
|
||||
self.event_results[handler_id] = result
|
||||
self._handler_metadata[handler_id] = {
|
||||
'handler': handler,
|
||||
'name': handler.__name__,
|
||||
'eventbus': eventbus,
|
||||
'eventbus_name': eventbus.name if eventbus else None,
|
||||
'eventbus_id': str(id(eventbus)) if eventbus else None
|
||||
}
|
||||
if complete:
|
||||
self.event_completed_at = datetime.now(UTC)
|
||||
if self._event_completed_signal:
|
||||
self._event_completed_signal.set()
|
||||
|
||||
def record_error(self, handler: Any, error: str) -> None:
|
||||
"""Record a handler error"""
|
||||
self.event_errors[str(id(handler))] = error
|
||||
|
||||
def mark_complete(self) -> None:
|
||||
"""Mark event as completed"""
|
||||
self.event_completed_at = datetime.now(UTC)
|
||||
if self._event_completed_signal:
|
||||
self._event_completed_signal.set()
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
"""Initialize completion event and set event schema after model creation"""
|
||||
@@ -75,3 +107,220 @@ class BaseEvent(BaseModel):
|
||||
self._event_completed_signal = asyncio.Event()
|
||||
except RuntimeError:
|
||||
self._event_completed_signal = None # Not in async context, skip
|
||||
|
||||
|
||||
class EventResults:
|
||||
"""Lazy mapping of handler results that supports efficient slicing and aggregation"""
|
||||
|
||||
def __init__(self, event: BaseEvent, eventbus: 'EventBus', timeout: float = 30.0, include_wildcards: bool = False):
|
||||
self.event = event
|
||||
self.eventbus = eventbus
|
||||
self.timeout = timeout
|
||||
try:
|
||||
self._start_time = asyncio.get_event_loop().time()
|
||||
except RuntimeError:
|
||||
# No event loop in sync context, use time.time()
|
||||
import time
|
||||
self._start_time = time.time()
|
||||
self.include_wildcards = include_wildcards
|
||||
|
||||
# Only precompute the first handler for efficient .first() access
|
||||
event_key = event.event_type
|
||||
specific_handlers = self.eventbus.handlers.get(event_key, [])
|
||||
wildcard_handlers = self.eventbus.handlers.get('*', []) if include_wildcards else []
|
||||
|
||||
local_handlers = specific_handlers + wildcard_handlers
|
||||
self._first_handler = local_handlers[0] if local_handlers else None
|
||||
|
||||
# Track EventBuses that have dispatched this event
|
||||
self._seen_eventbus_ids: set[str] = {str(id(eventbus))}
|
||||
|
||||
def __getitem__(self, key: str | Any):
|
||||
"""Get by handler ID, handler name, or handler function"""
|
||||
if callable(key):
|
||||
# Convert handler function to its ID
|
||||
key = str(id(key))
|
||||
if not isinstance(key, str):
|
||||
raise TypeError("EventResults accepts handler IDs (str), handler names (str), or handler functions")
|
||||
|
||||
# First try as handler ID
|
||||
if key in self.event.event_results:
|
||||
return self.event.event_results[key]
|
||||
|
||||
# Then try as handler name - find first matching handler
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if metadata['name'] == key and handler_id in self.event.event_results:
|
||||
return self.event.event_results[handler_id]
|
||||
|
||||
raise KeyError(f"No result found for key: {key}")
|
||||
|
||||
def __await__(self):
|
||||
"""Default to by_handler_id() when awaited directly"""
|
||||
return self.by_handler_id().__await__()
|
||||
|
||||
async def _wait_for_all_handlers(self) -> None:
|
||||
"""Wait for all handlers across all seen EventBuses to complete"""
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
while True:
|
||||
# Check if we've timed out
|
||||
if asyncio.get_event_loop().time() - start_time > self.timeout:
|
||||
logger.warning(f"Timeout waiting for all handlers after {self.timeout}s")
|
||||
break
|
||||
|
||||
# Get all eventbus IDs that have registered handlers
|
||||
seen_handler_buses = set()
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if metadata.get('eventbus_id'):
|
||||
seen_handler_buses.add(metadata['eventbus_id'])
|
||||
|
||||
# Check if all seen buses have registered their handlers
|
||||
all_buses_handled = self._seen_eventbus_ids.issubset(seen_handler_buses)
|
||||
|
||||
# Check if all registered handlers have results or errors
|
||||
all_handlers_done = all(
|
||||
handler_id in self.event.event_results or handler_id in self.event.event_errors
|
||||
for handler_id in self.event._handler_metadata
|
||||
)
|
||||
|
||||
if all_buses_handled and all_handlers_done:
|
||||
break
|
||||
|
||||
# Small delay before checking again
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _wait_for_handler(self, handler: Any) -> Any:
|
||||
"""Wait for a specific handler by id"""
|
||||
handler_id = str(id(handler))
|
||||
while handler_id not in self.event.event_results and not self._is_timeout():
|
||||
await asyncio.sleep(0.001)
|
||||
return self.event.event_results.get(handler_id)
|
||||
|
||||
async def _get_all_results(self) -> dict[str, Any]:
|
||||
"""Wait for all handlers to complete and return all results"""
|
||||
await self._wait_for_all_handlers()
|
||||
return self.event.event_results
|
||||
|
||||
async def first(self, default=None) -> Any:
|
||||
"""Get first handler result - only waits for the first handler"""
|
||||
if not self._first_handler:
|
||||
return default
|
||||
return await self._wait_for_handler(self._first_handler) or default
|
||||
|
||||
async def last(self) -> Any:
|
||||
"""Get last handler result - waits for all handlers to complete"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
# Get all handlers from metadata in order
|
||||
handler_items = list(self.event._handler_metadata.items())
|
||||
|
||||
if handler_items:
|
||||
# Get the last handler
|
||||
handler_id = handler_items[-1][0]
|
||||
return self.event.event_results.get(handler_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def by_handler_name(self) -> dict[str, Any]:
|
||||
"""Get results keyed by handler name (last handler with each name wins)"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
results = {}
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if handler_id in self.event.event_results:
|
||||
value = self.event.event_results[handler_id]
|
||||
if value is not None:
|
||||
results[metadata['name']] = value
|
||||
return results
|
||||
|
||||
async def by_handler_id(self) -> dict[str, Any]:
|
||||
"""Get results keyed by handler id string"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
results = {}
|
||||
for handler_id in self.event.event_results:
|
||||
value = self.event.event_results[handler_id]
|
||||
if value is not None:
|
||||
results[handler_id] = value
|
||||
return results
|
||||
|
||||
async def by_eventbus_id(self) -> dict[str, Any]:
|
||||
"""Get results keyed by eventbus id (last handler per eventbus wins)"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
results = {}
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if handler_id in self.event.event_results:
|
||||
eventbus_id = metadata.get('eventbus_id')
|
||||
value = self.event.event_results[handler_id]
|
||||
if value is not None and eventbus_id:
|
||||
results[eventbus_id] = value
|
||||
return results
|
||||
|
||||
async def by_path(self) -> dict[str, Any]:
|
||||
"""Get results keyed by path: eventbus_name#eventbus_id.handler_name"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
results = {}
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if handler_id in self.event.event_results:
|
||||
if metadata.get('eventbus_name') and metadata.get('eventbus_id'):
|
||||
path = f"{metadata['eventbus_name']}#{metadata['eventbus_id']}.{metadata['name']}"
|
||||
value = self.event.event_results[handler_id]
|
||||
if value is not None:
|
||||
results[path] = value
|
||||
return results
|
||||
|
||||
async def values(self) -> list[Any]:
|
||||
"""Get all results as list"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
# Return results in handler registration order
|
||||
results = []
|
||||
for handler_id in self.event._handler_metadata:
|
||||
if handler_id in self.event.event_results:
|
||||
value = self.event.event_results[handler_id]
|
||||
if value is not None:
|
||||
results.append(value)
|
||||
return results
|
||||
|
||||
async def flat_dict(self) -> dict[str, Any]:
|
||||
"""Merge results into single dict"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
merged = {}
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if handler_id in self.event.event_results:
|
||||
result = self.event.event_results[handler_id]
|
||||
if result is not None:
|
||||
if not isinstance(result, dict):
|
||||
handler_name = metadata.get('name', 'unknown')
|
||||
raise TypeError(f"Handler '{handler_name}' returned {type(result).__name__} instead of dict")
|
||||
merged.update(result)
|
||||
return merged
|
||||
|
||||
async def flat_list(self) -> list[Any]:
|
||||
"""Merge results into single list"""
|
||||
await self._wait_for_all_handlers()
|
||||
|
||||
merged = []
|
||||
for handler_id, metadata in self.event._handler_metadata.items():
|
||||
if handler_id in self.event.event_results:
|
||||
result = self.event.event_results[handler_id]
|
||||
if result is not None:
|
||||
if not isinstance(result, list):
|
||||
handler_name = metadata.get('name', 'unknown')
|
||||
raise TypeError(f"Handler '{handler_name}' returned {type(result).__name__} instead of list")
|
||||
merged.extend(result)
|
||||
return merged
|
||||
|
||||
def _is_timeout(self) -> bool:
|
||||
"""Check timeout"""
|
||||
try:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
except RuntimeError:
|
||||
# No event loop in sync context, use time.time()
|
||||
import time
|
||||
current_time = time.time()
|
||||
return current_time - self._start_time > self.timeout
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar, copy_context
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
@@ -11,13 +13,16 @@ import anyio
|
||||
from pydantic import BaseModel
|
||||
from uuid_extensions import uuid7str
|
||||
|
||||
from browser_use.eventbus.models import BaseEvent, UUIDStr
|
||||
from browser_use.eventbus.models import BaseEvent, EventResults, UUIDStr
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for event handlers
|
||||
EventHandler = Union[Callable[[BaseEvent], Any], Callable[[BaseEvent], Awaitable[Any]]]
|
||||
|
||||
# Context variable to track the current event being processed
|
||||
_current_event_context: ContextVar[BaseEvent | None] = ContextVar('current_event', default=None)
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
@@ -103,35 +108,69 @@ class EventBus:
|
||||
eventbus.on(TaskStartedEvent, handler) # Event model class
|
||||
eventbus.on('*', handler) # Subscribe to all events
|
||||
eventbus.on('*', other_eventbus.dispatch) # Forward all events to another EventBus
|
||||
|
||||
Note: When forwarding events between buses, all handler results are automatically
|
||||
flattened into the original event's results, so EventResults sees all handlers
|
||||
from all buses as a single flat collection.
|
||||
"""
|
||||
# Allow both sync and async handlers
|
||||
# Determine event key
|
||||
if event_pattern == '*':
|
||||
# Subscribe to all events using '*' as the key
|
||||
self.handlers['*'].append(handler)
|
||||
event_key = '*'
|
||||
elif isinstance(event_pattern, type) and issubclass(event_pattern, BaseModel):
|
||||
# Subscribe by model class
|
||||
self.handlers[event_pattern.__name__].append(handler)
|
||||
event_key = event_pattern.__name__
|
||||
else:
|
||||
# Subscribe by string event type
|
||||
self.handlers[str(event_pattern)].append(handler)
|
||||
event_key = str(event_pattern)
|
||||
|
||||
# Check for duplicate handler names
|
||||
handler_name = handler.__name__
|
||||
existing_names = [h.__name__ for h in self.handlers.get(event_key, [])]
|
||||
|
||||
if handler_name in existing_names:
|
||||
warnings.warn(
|
||||
f"⚠️ Handler '{handler_name}' already registered for event '{event_key}'. "
|
||||
f"This may cause ambiguous results when using name-based access. "
|
||||
f"Consider using unique function names.",
|
||||
UserWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
|
||||
# Register handler
|
||||
self.handlers[event_key].append(handler)
|
||||
logger.debug(f"✅ {self} Registered handler {handler_name} for event {event_key}")
|
||||
|
||||
# Auto-start if needed
|
||||
if not self._is_running:
|
||||
self._start()
|
||||
|
||||
def dispatch(self, event: BaseEvent) -> BaseEvent:
|
||||
def dispatch(self, event: BaseEvent, timeout: float = 30.0) -> EventResults:
|
||||
"""
|
||||
Enqueue an event for processing by the handlers. Returns awaitable event object.
|
||||
Enqueue an event for processing and return EventResults for accessing responses.
|
||||
(Auto-starts the EventBus's async _run_loop() if not already running)
|
||||
|
||||
Similar to JS EventListener.dispatchEvent() or eventbus.dispatch() in other languages.
|
||||
Returns EventResults which can be awaited directly or used to access specific results.
|
||||
Access the original event via EventResults.event.
|
||||
"""
|
||||
assert event.event_id, 'Missing event.event_id: UUIDStr = uuid7str()'
|
||||
assert event.event_created_at, 'Missing event.queued_at: datetime = datetime.now(UTC)'
|
||||
assert event.event_type and event.event_type.isidentifier(), 'Missing event.event_type: str'
|
||||
assert event.event_schema and '@' in event.event_schema, 'Missing event.event_schema: str (with @version)'
|
||||
|
||||
# Automatically set parent_event_id from context if not already set
|
||||
if event.parent_event_id is None:
|
||||
current_event = _current_event_context.get()
|
||||
if current_event is not None:
|
||||
event.parent_event_id = current_event.event_id
|
||||
|
||||
# Add this EventBus to the event_path if not already there
|
||||
if self.name not in event.event_path:
|
||||
# preserve identity of the original object instead of creating a new one, so that the original object remains awaitable to get the result
|
||||
# NOT: event = event.model_copy(update={'event_path': event.event_path + [self.name]})
|
||||
event.event_path.append(self.name)
|
||||
|
||||
# Store reference to this EventBus for result() method
|
||||
if not event._eventbus:
|
||||
event._eventbus = self
|
||||
|
||||
|
||||
assert event.event_path, 'Missing event.event_path: list[str] (with at least the origin function name recorded in it)'
|
||||
assert all(entry.isidentifier() for entry in event.event_path), (
|
||||
@@ -150,7 +189,64 @@ class EventBus:
|
||||
except asyncio.QueueFull:
|
||||
logger.error(f'⚠️ {self} Event queue is full! Dropping event {event.event_type}:\n{event.model_dump_json()}')
|
||||
|
||||
return event
|
||||
# Create or update the single EventResults instance
|
||||
if not event.results:
|
||||
event.results = EventResults(event, self, timeout)
|
||||
else:
|
||||
# Just update the seen eventbus IDs
|
||||
event.results._seen_eventbus_ids.add(str(id(self)))
|
||||
|
||||
return event.results
|
||||
|
||||
async def expect(
|
||||
self,
|
||||
event_type: str | type[BaseModel],
|
||||
timeout: float | None = None,
|
||||
predicate: Callable[[BaseEvent], bool] | None = None,
|
||||
) -> BaseEvent:
|
||||
"""
|
||||
Wait for an event matching the given type/pattern with optional predicate filter.
|
||||
|
||||
Args:
|
||||
event_type: The event type string or model class to wait for
|
||||
timeout: Maximum time to wait in seconds (None = wait forever)
|
||||
predicate: Optional filter function that must return True for the event to match
|
||||
|
||||
Returns:
|
||||
The first matching event
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If timeout is reached before a matching event
|
||||
|
||||
Example:
|
||||
# Wait for any response event
|
||||
response = await eventbus.expect('ResponseEvent', timeout=30)
|
||||
|
||||
# Wait for specific response with predicate
|
||||
response = await eventbus.expect(
|
||||
'ResponseEvent',
|
||||
predicate=lambda e: e.request_id == my_request_id,
|
||||
timeout=30
|
||||
)
|
||||
"""
|
||||
future: asyncio.Future[BaseEvent] = asyncio.Future()
|
||||
|
||||
def temporary_handler(event: BaseEvent) -> None:
|
||||
"""Handler that resolves the future when a matching event is found"""
|
||||
if not future.done() and (predicate is None or predicate(event)):
|
||||
future.set_result(event)
|
||||
|
||||
# Register temporary handler
|
||||
self.on(event_type, temporary_handler)
|
||||
|
||||
try:
|
||||
# Wait for the future with optional timeout
|
||||
return await asyncio.wait_for(future, timeout=timeout)
|
||||
finally:
|
||||
# Clean up handler
|
||||
event_key = event_type.__name__ if isinstance(event_type, type) else str(event_type)
|
||||
if event_key in self.handlers and temporary_handler in self.handlers[event_key]:
|
||||
self.handlers[event_key].remove(temporary_handler)
|
||||
|
||||
def _start(self) -> None:
|
||||
"""Start the event bus if not already running"""
|
||||
@@ -244,8 +340,8 @@ class EventBus:
|
||||
|
||||
await self._execute_handlers(event, handlers=applicable_handlers)
|
||||
|
||||
# Mark event as completed with empty results dict
|
||||
event.record_results(complete=True)
|
||||
# Mark event as completed
|
||||
event.mark_complete()
|
||||
|
||||
# Persist to WAL if configured
|
||||
if self.wal_path:
|
||||
@@ -266,14 +362,16 @@ class EventBus:
|
||||
# Add wildcard handlers (handlers registered for '*')
|
||||
applicable_handlers.extend(self.handlers.get('*', []))
|
||||
|
||||
# Filter out handlers that would create loops and build name->handler mapping
|
||||
# Filter out handlers that would create loops and build id->handler mapping
|
||||
# Use handler id as key to preserve all handlers even with duplicate names
|
||||
filtered_handlers = {}
|
||||
for handler in applicable_handlers:
|
||||
if self._would_create_loop(event, handler):
|
||||
logger.debug(f'Skipping {handler.__name__} to prevent loop for {event.event_type}')
|
||||
continue
|
||||
else:
|
||||
filtered_handlers[handler.__name__] = handler
|
||||
handler_id = str(id(handler))
|
||||
filtered_handlers[handler_id] = handler
|
||||
|
||||
return filtered_handlers
|
||||
|
||||
@@ -286,46 +384,56 @@ class EventBus:
|
||||
# Execute all handlers in parallel
|
||||
if self.parallel_handlers:
|
||||
handler_tasks = {}
|
||||
for handler_name, handler in applicable_handlers.items():
|
||||
for handler_id, handler in applicable_handlers.items():
|
||||
task = asyncio.create_task(self._execute_sync_or_async_handler(event, handler))
|
||||
handler_tasks[handler_name] = task
|
||||
handler_tasks[handler_id] = (task, handler)
|
||||
|
||||
# Wait for all handlers to complete and record results incrementally
|
||||
for handler_name, task in handler_tasks.items():
|
||||
for handler_id, (task, handler) in handler_tasks.items():
|
||||
try:
|
||||
result = await task
|
||||
event.record_results({handler_name: result}, complete=False)
|
||||
event.record_result(handler, result, self)
|
||||
except Exception as e:
|
||||
event.event_errors[handler_name] = str(e)
|
||||
event.record_error(handler, str(e))
|
||||
logger.error(
|
||||
f'❌ {self} Handler {handler_name} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}'
|
||||
f'❌ {self} Handler {handler.__name__} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}'
|
||||
)
|
||||
else:
|
||||
# otherwise, execute handlers serially, wait until each one completes before moving on to the next
|
||||
for handler_name, handler in applicable_handlers.items():
|
||||
for handler_id, handler in applicable_handlers.items():
|
||||
try:
|
||||
result = await self._execute_sync_or_async_handler(event, handler)
|
||||
event.record_results({handler_name: result}, complete=False)
|
||||
event.record_result(handler, result, self)
|
||||
except Exception as e:
|
||||
event.event_errors[handler_name] = str(e)
|
||||
event.record_error(handler, str(e))
|
||||
logger.error(
|
||||
f'❌ {self} Handler {handler_name} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}'
|
||||
f'❌ {self} Handler {handler.__name__} failed for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}'
|
||||
)
|
||||
|
||||
async def _execute_sync_or_async_handler(self, event: BaseEvent, handler: EventHandler) -> Any:
|
||||
"""Safely execute a single handler"""
|
||||
# Set the current event in context so child events can reference it
|
||||
token = _current_event_context.set(event)
|
||||
try:
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
return await handler(event)
|
||||
else:
|
||||
# Run sync handler in thread pool to avoid blocking
|
||||
# Run sync handler in thread pool with context preserved
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(None, handler, event)
|
||||
ctx = copy_context()
|
||||
# Create a wrapper that preserves the context
|
||||
def context_preserving_wrapper():
|
||||
# The context is already copied, just run the handler
|
||||
return handler(event)
|
||||
return await loop.run_in_executor(None, ctx.run, context_preserving_wrapper)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'❌ {self} Error in handler {handler.__name__} for event {event.event_id}: {type(e).__name__} {e}\n{event.model_dump()}'
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Reset context
|
||||
_current_event_context.reset(token)
|
||||
|
||||
@staticmethod
|
||||
def _would_create_loop(event: BaseEvent, handler: EventHandler) -> bool:
|
||||
|
||||
@@ -24,6 +24,7 @@ from pydantic import Field
|
||||
|
||||
from browser_use.agent.cloud_events import CreateAgentTaskEvent
|
||||
from browser_use.eventbus import BaseEvent, EventBus
|
||||
from browser_use.eventbus.models import EventResults
|
||||
|
||||
|
||||
# Test event models
|
||||
@@ -422,7 +423,7 @@ class TestEdgeCases:
|
||||
tasks = []
|
||||
for i in range(100):
|
||||
event = UserActionEvent(action=f'concurrent_{i}', user_id='u1')
|
||||
# Emit returns the event synchronously, but we need to wait for completion
|
||||
# Emit returns the event syncresultsonously, but we need to wait for completion
|
||||
emitted_event = eventbus.dispatch(event)
|
||||
tasks.append(emitted_event.result())
|
||||
|
||||
@@ -618,9 +619,9 @@ class TestWALPersistence:
|
||||
class TestEventBusHierarchy:
|
||||
"""Test hierarchical EventBus subscription patterns"""
|
||||
|
||||
async def test_three_level_hierarchy_bubbling(self):
|
||||
"""Test that events bubble up through a 3-level hierarchy and event_path is correct"""
|
||||
# Create three EventBus instances in a hierarchy
|
||||
async def test_tresultsee_level_hierarchy_bubbling(self):
|
||||
"""Test that events bubble up tresultsough a 3-level hierarchy and event_path is correct"""
|
||||
# Create tresultsee EventBus instances in a hierarchy
|
||||
parent_bus = EventBus(name='ParentBus')
|
||||
child_bus = EventBus(name='ChildBus')
|
||||
subchild_bus = EventBus(name='SubchildBus')
|
||||
@@ -701,7 +702,7 @@ class TestEventBusHierarchy:
|
||||
|
||||
async def test_circular_subscription_prevention(self):
|
||||
"""Test that circular EventBus subscriptions don't create infinite loops"""
|
||||
# Create three peer EventBus instances
|
||||
# Create tresultsee peer EventBus instances
|
||||
peer1 = EventBus(name='Peer1')
|
||||
peer2 = EventBus(name='Peer2')
|
||||
peer3 = EventBus(name='Peer3')
|
||||
@@ -788,5 +789,675 @@ class TestEventBusHierarchy:
|
||||
await peer3.stop()
|
||||
|
||||
|
||||
class TestExpectMethod:
|
||||
"""Test the expect() method functionality"""
|
||||
|
||||
async def test_expect_basic(self, eventbus):
|
||||
"""Test basic expect functionality"""
|
||||
# Start waiting for an event that hasn't been dispatched yet
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect('UserActionEvent', timeout=1.0)
|
||||
)
|
||||
|
||||
# Give expect time to register handler
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Dispatch the event
|
||||
dispatched = eventbus.dispatch(UserActionEvent(action='login', user_id='user123'))
|
||||
|
||||
# Wait for expect to resolve
|
||||
received = await expect_task
|
||||
|
||||
# Verify we got the right event
|
||||
assert received.event_type == 'UserActionEvent'
|
||||
assert received.action == 'login'
|
||||
assert received.user_id == 'user123'
|
||||
assert received.event_id == dispatched.event_id
|
||||
|
||||
async def test_expect_with_predicate(self, eventbus):
|
||||
"""Test expect with predicate filtering"""
|
||||
# Dispatch some events that don't match
|
||||
eventbus.dispatch(UserActionEvent(action='logout', user_id='user456'))
|
||||
eventbus.dispatch(UserActionEvent(action='login', user_id='user789'))
|
||||
|
||||
# Start expecting with predicate
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect(
|
||||
'UserActionEvent',
|
||||
predicate=lambda e: e.user_id == 'user123',
|
||||
timeout=1.0
|
||||
)
|
||||
)
|
||||
|
||||
# Give expect time to register
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Dispatch more events
|
||||
eventbus.dispatch(UserActionEvent(action='update', user_id='user456'))
|
||||
target_event = eventbus.dispatch(UserActionEvent(action='login', user_id='user123'))
|
||||
eventbus.dispatch(UserActionEvent(action='delete', user_id='user789'))
|
||||
|
||||
# Wait for the matching event
|
||||
received = await expect_task
|
||||
|
||||
# Should get the event matching the predicate
|
||||
assert received.user_id == 'user123'
|
||||
assert received.event_id == target_event.event_id
|
||||
|
||||
async def test_expect_timeout(self, eventbus):
|
||||
"""Test expect timeout behavior"""
|
||||
# Expect an event that will never come
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await eventbus.expect('NonExistentEvent', timeout=0.1)
|
||||
|
||||
async def test_expect_with_model_class(self, eventbus):
|
||||
"""Test expect with model class instead of string"""
|
||||
# Start expecting by model class
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect(SystemEventModel, timeout=1.0)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Dispatch different event types
|
||||
eventbus.dispatch(UserActionEvent(action='test', user_id='u1'))
|
||||
target = eventbus.dispatch(SystemEventModel(event_name='startup', severity='info'))
|
||||
|
||||
# Should receive the SystemEventModel
|
||||
received = await expect_task
|
||||
assert isinstance(received, SystemEventModel)
|
||||
assert received.event_name == 'startup'
|
||||
assert received.event_id == target.event_id
|
||||
|
||||
async def test_multiple_concurrent_expects(self, eventbus):
|
||||
"""Test multiple concurrent expect calls"""
|
||||
# Set up multiple expects for different events
|
||||
expect1 = asyncio.create_task(
|
||||
eventbus.expect(
|
||||
'UserActionEvent',
|
||||
predicate=lambda e: e.action == 'normal',
|
||||
timeout=2.0
|
||||
)
|
||||
)
|
||||
expect2 = asyncio.create_task(
|
||||
eventbus.expect('SystemEventModel', timeout=2.0)
|
||||
)
|
||||
expect3 = asyncio.create_task(
|
||||
eventbus.expect(
|
||||
'UserActionEvent',
|
||||
predicate=lambda e: e.action == 'special',
|
||||
timeout=2.0
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.1) # Give more time for handlers to register
|
||||
|
||||
# Dispatch events
|
||||
e1 = eventbus.dispatch(UserActionEvent(action='normal', user_id='u1'))
|
||||
e2 = eventbus.dispatch(SystemEventModel(event_name='test'))
|
||||
e3 = eventbus.dispatch(UserActionEvent(action='special', user_id='u2'))
|
||||
|
||||
# Wait for all events to be processed
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Wait for all expects
|
||||
r1, r2, r3 = await asyncio.gather(expect1, expect2, expect3)
|
||||
|
||||
# Verify results
|
||||
assert r1.event_id == e1.event_id # Normal UserActionEvent
|
||||
assert r2.event_id == e2.event_id # SystemEventModel
|
||||
assert r3.event_id == e3.event_id # Special UserActionEvent
|
||||
|
||||
async def test_expect_handler_cleanup(self, eventbus):
|
||||
"""Test that temporary handlers are properly cleaned up"""
|
||||
# Check initial handler count
|
||||
initial_handlers = len(eventbus.handlers.get('TestEvent', []))
|
||||
|
||||
# Create an expect that times out
|
||||
try:
|
||||
await eventbus.expect('TestEvent', timeout=0.1)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Handler should be cleaned up
|
||||
assert len(eventbus.handlers.get('TestEvent', [])) == initial_handlers
|
||||
|
||||
# Create an expect that succeeds
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect('TestEvent2', timeout=1.0)
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
eventbus.dispatch(BaseEvent(event_type='TestEvent2'))
|
||||
await expect_task
|
||||
|
||||
# Handler should be cleaned up
|
||||
assert len(eventbus.handlers.get('TestEvent2', [])) == 0
|
||||
|
||||
async def test_expect_receives_completed_event(self, eventbus):
|
||||
"""Test that expect receives events after they're fully processed"""
|
||||
processing_complete = False
|
||||
|
||||
async def slow_handler(event: BaseEvent) -> str:
|
||||
await asyncio.sleep(0.1)
|
||||
nonlocal processing_complete
|
||||
processing_complete = True
|
||||
return 'done'
|
||||
|
||||
# Register a slow handler
|
||||
eventbus.on('SlowEvent', slow_handler)
|
||||
|
||||
# Start expecting
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect('SlowEvent', timeout=1.0)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Dispatch event
|
||||
eventbus.dispatch(BaseEvent(event_type='SlowEvent'))
|
||||
|
||||
# Wait for expect
|
||||
received = await expect_task
|
||||
|
||||
# At this point, the slow handler should have run
|
||||
# but we receive the event as soon as it matches
|
||||
assert received.event_type == 'SlowEvent'
|
||||
# The event might not be fully completed yet since expect
|
||||
# triggers as soon as the event is processed by its handler
|
||||
|
||||
async def test_expect_with_complex_predicate(self, eventbus):
|
||||
"""Test expect with complex predicate logic"""
|
||||
events_seen = []
|
||||
|
||||
def complex_predicate(event: UserActionEvent) -> bool:
|
||||
events_seen.append(event.action)
|
||||
# Only match after seeing at least 3 events and action is 'target'
|
||||
return len(events_seen) >= 3 and event.action == 'target'
|
||||
|
||||
expect_task = asyncio.create_task(
|
||||
eventbus.expect(
|
||||
'UserActionEvent',
|
||||
predicate=complex_predicate,
|
||||
timeout=1.0
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Dispatch events
|
||||
eventbus.dispatch(UserActionEvent(action='first', user_id='u1'))
|
||||
eventbus.dispatch(UserActionEvent(action='second', user_id='u2'))
|
||||
eventbus.dispatch(UserActionEvent(action='target', user_id='u3')) # Won't match yet
|
||||
eventbus.dispatch(UserActionEvent(action='target', user_id='u4')) # This should match
|
||||
|
||||
received = await expect_task
|
||||
|
||||
assert received.user_id == 'u4'
|
||||
assert len(events_seen) == 4
|
||||
|
||||
async def test_expect_in_sync_context(self, mock_agent):
|
||||
"""Test that expect can be used from sync code that later awaits"""
|
||||
bus = EventBus()
|
||||
|
||||
# This simulates calling expect from sync code
|
||||
expect_coroutine = bus.expect('SyncEvent', timeout=1.0)
|
||||
|
||||
# Dispatch event
|
||||
bus.dispatch(BaseEvent(event_type='SyncEvent'))
|
||||
|
||||
# Later await the coroutine
|
||||
result = await expect_coroutine
|
||||
assert result.event_type == 'SyncEvent'
|
||||
|
||||
await bus.stop()
|
||||
|
||||
|
||||
class TestEventResults:
|
||||
"""Test the new EventResults functionality"""
|
||||
|
||||
async def test_dispatch_returns_event_results(self, eventbus):
|
||||
"""Test that dispatch now returns EventResults"""
|
||||
# Register a specific handler
|
||||
async def test_handler(event):
|
||||
return 'test_result'
|
||||
|
||||
eventbus.on('UserActionEvent', test_handler)
|
||||
|
||||
result = eventbus.dispatch(UserActionEvent(action='test', user_id='u1'))
|
||||
assert isinstance(result, EventResults)
|
||||
|
||||
# Can be awaited directly (returns by_handler_id dict)
|
||||
all_results = await result
|
||||
assert isinstance(all_results, dict)
|
||||
# Should contain both test_handler and default log handler results
|
||||
assert len(all_results) == 2
|
||||
assert 'test_result' in all_results.values()
|
||||
assert 'logged' in all_results.values()
|
||||
|
||||
# Test with no specific handlers (only wildcard)
|
||||
result_no_handlers = eventbus.dispatch(BaseEvent(event_type='NoHandlersEvent'))
|
||||
assert await result_no_handlers.first() is None # No specific handlers
|
||||
|
||||
# Test including wildcards explicitly
|
||||
result_with_wildcards = EventResults(result_no_handlers.event, eventbus, include_wildcards=True)
|
||||
assert await result_with_wildcards.first() == 'logged' # Wildcard handler result
|
||||
|
||||
async def test_event_results_indexing(self, eventbus):
|
||||
"""Test indexing by handler name and ID"""
|
||||
order = []
|
||||
|
||||
async def handler1(event):
|
||||
order.append(1)
|
||||
return 'first'
|
||||
|
||||
async def handler2(event):
|
||||
order.append(2)
|
||||
return 'second'
|
||||
|
||||
async def handler3(event):
|
||||
order.append(3)
|
||||
return 'third'
|
||||
|
||||
eventbus.on('TestEvent', handler1)
|
||||
eventbus.on('TestEvent', handler2)
|
||||
eventbus.on('TestEvent', handler3)
|
||||
|
||||
# Test indexing
|
||||
hr = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
|
||||
# Wait for all handlers to complete
|
||||
await hr
|
||||
|
||||
# Index by handler name
|
||||
assert hr['handler1'] == 'first'
|
||||
assert hr['handler2'] == 'second'
|
||||
assert hr['handler3'] == 'third'
|
||||
|
||||
# Index by handler function
|
||||
assert hr[handler1] == 'first'
|
||||
assert hr[handler2] == 'second'
|
||||
assert hr[handler3] == 'third'
|
||||
|
||||
async def test_first_last_methods(self, eventbus):
|
||||
"""Test first() and last() methods"""
|
||||
async def early_handler(event):
|
||||
return 'early'
|
||||
|
||||
async def late_handler(event):
|
||||
await asyncio.sleep(0.01)
|
||||
return 'late'
|
||||
|
||||
eventbus.on('TestEvent', early_handler)
|
||||
eventbus.on('TestEvent', late_handler)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
|
||||
# first() returns first handler result
|
||||
assert await results.first() == 'early'
|
||||
|
||||
# last() returns last handler result
|
||||
assert await results.last() == 'late'
|
||||
|
||||
# With empty handlers
|
||||
eventbus.handlers['EmptyEvent'] = []
|
||||
results_empty = eventbus.dispatch(BaseEvent(event_type='EmptyEvent'))
|
||||
assert await results_empty.first(default='none') == 'none'
|
||||
assert await results_empty.last() is None
|
||||
|
||||
async def test_by_handler_name(self, eventbus):
|
||||
"""Test by_handler_name() with duplicate names"""
|
||||
async def process_data(event):
|
||||
return 'version1'
|
||||
|
||||
async def process_data(event): # Same name!
|
||||
return 'version2'
|
||||
|
||||
async def unique_handler(event):
|
||||
return 'unique'
|
||||
|
||||
# Should get warning about duplicate name
|
||||
with pytest.warns(UserWarning, match='already registered'):
|
||||
eventbus.on('TestEvent', process_data)
|
||||
eventbus.on('TestEvent', process_data)
|
||||
eventbus.on('TestEvent', unique_handler)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
results = await results.by_handler_name()
|
||||
|
||||
# Last handler with same name wins
|
||||
assert results['process_data'] == 'version2'
|
||||
assert results['unique_handler'] == 'unique'
|
||||
# Default log handler is included as a wildcard handler
|
||||
assert '_default_log_handler' in results
|
||||
assert len(results) == 3 # 2 test handlers + 1 default log handler
|
||||
|
||||
async def test_by_handler_id(self, eventbus):
|
||||
"""Test by_handler_id() returns all handlers uniquely"""
|
||||
async def handler1(event):
|
||||
return 'v1'
|
||||
|
||||
async def handler2(event):
|
||||
return 'v2'
|
||||
|
||||
# Give them the same name for the test
|
||||
handler1.__name__ = 'handler'
|
||||
handler2.__name__ = 'handler'
|
||||
|
||||
eventbus.on('TestEvent', handler1)
|
||||
eventbus.on('TestEvent', handler2)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
results = await results.by_handler_id()
|
||||
|
||||
# All handlers present with unique IDs even with same name
|
||||
assert len(results) == 2
|
||||
assert 'v1' in results.values()
|
||||
assert 'v2' in results.values()
|
||||
|
||||
async def test_flat_dict(self, eventbus):
|
||||
"""Test flat_dict() merging"""
|
||||
async def config_base(event):
|
||||
return {'debug': False, 'port': 8080, 'name': 'base'}
|
||||
|
||||
async def config_override(event):
|
||||
return {'debug': True, 'timeout': 30, 'name': 'override'}
|
||||
|
||||
eventbus.on('GetConfig', config_base)
|
||||
eventbus.on('GetConfig', config_override)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='GetConfig'))
|
||||
merged = await results.flat_dict()
|
||||
|
||||
# Later handlers override earlier ones
|
||||
assert merged == {
|
||||
'debug': True, # Overridden
|
||||
'port': 8080, # From base
|
||||
'timeout': 30, # From override
|
||||
'name': 'override' # Overridden
|
||||
}
|
||||
|
||||
# Test type error
|
||||
async def bad_handler(event):
|
||||
return 'not a dict'
|
||||
|
||||
eventbus.on('BadConfig', bad_handler)
|
||||
results_bad = eventbus.dispatch(BaseEvent(event_type='BadConfig'))
|
||||
|
||||
with pytest.raises(TypeError, match='returned str instead of dict'):
|
||||
await results_bad.flat_dict()
|
||||
|
||||
async def test_flat_list(self, eventbus):
|
||||
"""Test flat_list() concatenation"""
|
||||
async def errors1(event):
|
||||
return ['error1', 'error2']
|
||||
|
||||
async def errors2(event):
|
||||
return ['error3']
|
||||
|
||||
async def errors3(event):
|
||||
return ['error4', 'error5']
|
||||
|
||||
eventbus.on('GetErrors', errors1)
|
||||
eventbus.on('GetErrors', errors2)
|
||||
eventbus.on('GetErrors', errors3)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='GetErrors'))
|
||||
all_errors = await results.flat_list()
|
||||
|
||||
assert all_errors == ['error1', 'error2', 'error3', 'error4', 'error5']
|
||||
|
||||
|
||||
# Test type error
|
||||
async def bad_handler(event):
|
||||
return {'not': 'a list'}
|
||||
|
||||
eventbus.on('BadList', bad_handler)
|
||||
results_bad = eventbus.dispatch(BaseEvent(event_type='BadList'))
|
||||
|
||||
with pytest.raises(TypeError, match='returned dict instead of list'):
|
||||
await results_bad.flat_list()
|
||||
|
||||
async def test_by_handler_name_access(self, eventbus):
|
||||
"""Test by_handler_name() method for name-based access"""
|
||||
async def handler_a(event):
|
||||
return 'result_a'
|
||||
|
||||
async def handler_b(event):
|
||||
return 'result_b'
|
||||
|
||||
eventbus.on('TestEvent', handler_a)
|
||||
eventbus.on('TestEvent', handler_b)
|
||||
|
||||
results = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
|
||||
by_name = await results.by_handler_name()
|
||||
assert by_name.get('handler_a') == 'result_a'
|
||||
assert by_name.get('handler_b') == 'result_b'
|
||||
assert by_name.get('nonexistent', 'fallback') == 'fallback'
|
||||
|
||||
async def test_string_indexing(self, eventbus):
|
||||
"""Test string indexing for handler access"""
|
||||
async def my_handler(event):
|
||||
return 'my_result'
|
||||
|
||||
eventbus.on('TestEvent', my_handler)
|
||||
results = eventbus.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
|
||||
# Wait for handlers to complete
|
||||
await results
|
||||
|
||||
# String indexing by handler name
|
||||
assert results['my_handler'] == 'my_result'
|
||||
|
||||
# Missing key raises KeyError
|
||||
with pytest.raises(KeyError, match='No result found for key: missing'):
|
||||
results['missing']
|
||||
|
||||
|
||||
class TestEventBusForwarding:
|
||||
"""Test event forwarding between buses with new EventResults"""
|
||||
|
||||
async def test_forwarding_flattens_results(self):
|
||||
"""Test that forwarding events between buses flattens all results"""
|
||||
bus1 = EventBus(name='Bus1')
|
||||
bus2 = EventBus(name='Bus2')
|
||||
bus3 = EventBus(name='Bus3')
|
||||
|
||||
results = []
|
||||
|
||||
async def bus1_handler(event):
|
||||
results.append('bus1')
|
||||
return 'from_bus1'
|
||||
|
||||
async def bus2_handler(event):
|
||||
results.append('bus2')
|
||||
return 'from_bus2'
|
||||
|
||||
async def bus3_handler(event):
|
||||
results.append('bus3')
|
||||
return 'from_bus3'
|
||||
|
||||
# Register handlers
|
||||
bus1.on('TestEvent', bus1_handler)
|
||||
bus2.on('TestEvent', bus2_handler)
|
||||
bus3.on('TestEvent', bus3_handler)
|
||||
|
||||
# Set up forwarding chain
|
||||
bus1.on('*', bus2.dispatch)
|
||||
bus2.on('*', bus3.dispatch)
|
||||
|
||||
try:
|
||||
# Dispatch from bus1
|
||||
results = bus1.dispatch(BaseEvent(event_type='TestEvent'))
|
||||
|
||||
# Wait for all buses to complete processing
|
||||
await bus1.wait_until_idle()
|
||||
await bus2.wait_until_idle()
|
||||
await bus3.wait_until_idle()
|
||||
|
||||
# All handlers from all buses should be visible
|
||||
all_results = await results.by_handler_name()
|
||||
assert 'bus1_handler' in all_results
|
||||
assert 'bus2_handler' in all_results
|
||||
assert 'bus3_handler' in all_results
|
||||
|
||||
# Results should be flattened
|
||||
assert all_results['bus1_handler'] == 'from_bus1'
|
||||
assert all_results['bus2_handler'] == 'from_bus2'
|
||||
assert all_results['bus3_handler'] == 'from_bus3'
|
||||
|
||||
# Check execution order
|
||||
assert results == ['bus1', 'bus2', 'bus3']
|
||||
|
||||
finally:
|
||||
await bus1.stop()
|
||||
await bus2.stop()
|
||||
await bus3.stop()
|
||||
|
||||
async def test_by_eventbus_id_and_path(self):
|
||||
"""Test by_eventbus_id() and by_path() with forwarding"""
|
||||
bus1 = EventBus(name='MainBus')
|
||||
bus2 = EventBus(name='PluginBus')
|
||||
|
||||
async def main_handler(event):
|
||||
return 'main_result'
|
||||
|
||||
async def plugin_handler1(event):
|
||||
return 'plugin_result1'
|
||||
|
||||
async def plugin_handler2(event):
|
||||
return 'plugin_result2'
|
||||
|
||||
bus1.on('DataEvent', main_handler)
|
||||
bus2.on('DataEvent', plugin_handler1)
|
||||
bus2.on('DataEvent', plugin_handler2)
|
||||
|
||||
# Forward from bus1 to bus2
|
||||
bus1.on('*', bus2.dispatch)
|
||||
|
||||
try:
|
||||
results = bus1.dispatch(BaseEvent(event_type='DataEvent'))
|
||||
|
||||
# Test by_eventbus_id
|
||||
by_bus = await results.by_eventbus_id()
|
||||
assert len(by_bus) == 2 # One per bus
|
||||
assert str(id(bus1)) in by_bus
|
||||
assert str(id(bus2)) in by_bus
|
||||
# Last handler per bus wins
|
||||
assert by_bus[str(id(bus2))] == 'plugin_result2'
|
||||
|
||||
# Test by_path
|
||||
by_path = await results.by_path()
|
||||
assert f'MainBus#{id(bus1)}.main_handler' in by_path
|
||||
assert f'PluginBus#{id(bus2)}.plugin_handler1' in by_path
|
||||
assert f'PluginBus#{id(bus2)}.plugin_handler2' in by_path
|
||||
|
||||
finally:
|
||||
await bus1.stop()
|
||||
await bus2.stop()
|
||||
|
||||
|
||||
class TestComplexIntegration:
|
||||
"""Complex integration test with all features"""
|
||||
|
||||
async def test_complex_multi_bus_scenario(self):
|
||||
"""Test complex scenario with multiple buses, duplicate names, and all query methods"""
|
||||
# Create a hierarchy of buses
|
||||
app_bus = EventBus(name='AppBus')
|
||||
auth_bus = EventBus(name='AuthBus')
|
||||
data_bus = EventBus(name='DataBus')
|
||||
|
||||
# Handlers with conflicting names
|
||||
async def validate(event):
|
||||
"""App validation"""
|
||||
return {'app_valid': True, 'timestamp': 1000}
|
||||
|
||||
async def validate(event):
|
||||
"""Auth validation"""
|
||||
return {'auth_valid': True, 'user': 'alice'}
|
||||
|
||||
async def validate(event):
|
||||
"""Data validation"""
|
||||
return {'data_valid': True, 'schema': 'v2'}
|
||||
|
||||
async def process(event):
|
||||
"""Auth processing"""
|
||||
return ['auth_log_1', 'auth_log_2']
|
||||
|
||||
async def process(event):
|
||||
"""Data processing"""
|
||||
return ['data_log_1', 'data_log_2', 'data_log_3']
|
||||
|
||||
# Register handlers with same names on different buses
|
||||
app_bus.on('ValidationRequest', validate)
|
||||
auth_bus.on('ValidationRequest', validate)
|
||||
auth_bus.on('ValidationRequest', process) # Different return type!
|
||||
data_bus.on('ValidationRequest', validate)
|
||||
data_bus.on('ValidationRequest', process)
|
||||
|
||||
# Set up forwarding
|
||||
app_bus.on('*', auth_bus.dispatch)
|
||||
auth_bus.on('*', data_bus.dispatch)
|
||||
|
||||
try:
|
||||
# Dispatch event
|
||||
results = app_bus.dispatch(BaseEvent(event_type='ValidationRequest'))
|
||||
|
||||
# Test all access methods
|
||||
|
||||
# 1. Direct await (first result)
|
||||
first = await results
|
||||
assert first == 'logged' # Default logger
|
||||
|
||||
# 2. Slicing
|
||||
validation_results = await results[1:4].values()
|
||||
assert len(validation_results) == 3
|
||||
|
||||
# 3. by_handler_name - duplicates overwrite
|
||||
by_name = await results.by_handler_name()
|
||||
assert by_name['validate'] == {'data_valid': True, 'schema': 'v2'} # Last wins
|
||||
assert by_name['process'] == ['data_log_1', 'data_log_2', 'data_log_3'] # Last wins
|
||||
|
||||
# 4. by_handler_id - all unique
|
||||
by_id = await results.by_handler_id()
|
||||
assert len(by_id) >= 5 # At least 5 handlers
|
||||
|
||||
# 5. flat_dict - only dict results
|
||||
dict_handlers = [h for h in by_id.values() if isinstance(h, dict)]
|
||||
assert len(dict_handlers) >= 3 # 3 validate handlers
|
||||
|
||||
# 6. flat_list - only list results
|
||||
list_handlers = [h for h in by_id.values() if isinstance(h, list)]
|
||||
assert len(list_handlers) >= 2 # 2 process handlers
|
||||
|
||||
# 7. Mixed flat operations with slicing
|
||||
# Skip default logger and get only validation dicts
|
||||
with pytest.raises(TypeError, match='instead of dict'):
|
||||
# This should fail because we're mixing dicts and lists
|
||||
await results[1:].flat_dict()
|
||||
|
||||
# 8. by_eventbus_id
|
||||
by_bus = await results.by_eventbus_id()
|
||||
assert len(by_bus) == 3 # One result per bus
|
||||
|
||||
# 9. by_path for full traceability
|
||||
by_path = await results.by_path()
|
||||
paths = list(by_path.keys())
|
||||
assert any('AppBus#' in p and '.validate' in p for p in paths)
|
||||
assert any('AuthBus#' in p and '.validate' in p for p in paths)
|
||||
assert any('AuthBus#' in p and '.process' in p for p in paths)
|
||||
assert any('DataBus#' in p and '.validate' in p for p in paths)
|
||||
assert any('DataBus#' in p and '.process' in p for p in paths)
|
||||
|
||||
# 10. Test handlers property
|
||||
assert 'validate' in results.handlers
|
||||
assert 'process' in results.handlers
|
||||
|
||||
finally:
|
||||
await app_bus.stop()
|
||||
await auth_bus.stop()
|
||||
await data_bus.stop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
|
||||
285
tests/ci/test_parent_event_tracking.py
Normal file
285
tests/ci/test_parent_event_tracking.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Test parent event tracking functionality in EventBus.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from browser_use.eventbus import BaseEvent, EventBus
|
||||
|
||||
|
||||
class ParentEvent(BaseEvent):
|
||||
"""Parent event that triggers child events"""
|
||||
event_type: str = 'ParentEvent'
|
||||
message: str
|
||||
|
||||
|
||||
class ChildEvent(BaseEvent):
|
||||
"""Child event triggered by parent"""
|
||||
event_type: str = 'ChildEvent'
|
||||
data: str
|
||||
|
||||
|
||||
class GrandchildEvent(BaseEvent):
|
||||
"""Grandchild event triggered by child"""
|
||||
event_type: str = 'GrandchildEvent'
|
||||
value: int
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def eventbus():
|
||||
"""Create an event bus for testing"""
|
||||
bus = EventBus(name='TestBus')
|
||||
yield bus
|
||||
await bus.stop()
|
||||
|
||||
|
||||
class TestParentEventTracking:
|
||||
"""Test automatic parent event ID tracking"""
|
||||
|
||||
async def test_basic_parent_tracking(self, eventbus):
|
||||
"""Test that child events automatically get parent_event_id"""
|
||||
child_events = []
|
||||
|
||||
async def parent_handler(event: ParentEvent) -> str:
|
||||
# Handler that dispatches a child event
|
||||
child = ChildEvent(data=f"child_of_{event.message}")
|
||||
eventbus.dispatch(child)
|
||||
child_events.append(child)
|
||||
return 'parent_handled'
|
||||
|
||||
eventbus.on('ParentEvent', parent_handler)
|
||||
|
||||
# Dispatch parent event
|
||||
parent = ParentEvent(message='test_parent')
|
||||
parent_result = eventbus.dispatch(parent)
|
||||
|
||||
# Wait for processing
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Verify parent processed
|
||||
assert await parent_result.get('parent_handler') == 'parent_handled'
|
||||
|
||||
# Verify child has parent_event_id set
|
||||
assert len(child_events) == 1
|
||||
child = child_events[0]
|
||||
assert child.parent_event_id == parent.event_id
|
||||
|
||||
async def test_multi_level_parent_tracking(self, eventbus):
|
||||
"""Test parent tracking across multiple levels"""
|
||||
events_by_level = {'parent': None, 'child': None, 'grandchild': None}
|
||||
|
||||
async def parent_handler(event: ParentEvent) -> str:
|
||||
events_by_level['parent'] = event
|
||||
child = ChildEvent(data='child_data')
|
||||
eventbus.dispatch(child)
|
||||
return 'parent'
|
||||
|
||||
async def child_handler(event: ChildEvent) -> str:
|
||||
events_by_level['child'] = event
|
||||
grandchild = GrandchildEvent(value=42)
|
||||
eventbus.dispatch(grandchild)
|
||||
return 'child'
|
||||
|
||||
async def grandchild_handler(event: GrandchildEvent) -> str:
|
||||
events_by_level['grandchild'] = event
|
||||
return 'grandchild'
|
||||
|
||||
# Register handlers
|
||||
eventbus.on('ParentEvent', parent_handler)
|
||||
eventbus.on('ChildEvent', child_handler)
|
||||
eventbus.on('GrandchildEvent', grandchild_handler)
|
||||
|
||||
# Start the chain
|
||||
parent = ParentEvent(message='root')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
# Wait for all processing
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Verify the parent chain
|
||||
assert events_by_level['parent'].parent_event_id is None # Root has no parent
|
||||
assert events_by_level['child'].parent_event_id == parent.event_id
|
||||
assert events_by_level['grandchild'].parent_event_id == events_by_level['child'].event_id
|
||||
|
||||
async def test_multiple_children_same_parent(self, eventbus):
|
||||
"""Test multiple child events from same parent"""
|
||||
child_events = []
|
||||
|
||||
async def parent_handler(event: ParentEvent) -> str:
|
||||
# Dispatch multiple children
|
||||
for i in range(3):
|
||||
child = ChildEvent(data=f"child_{i}")
|
||||
eventbus.dispatch(child)
|
||||
child_events.append(child)
|
||||
return 'spawned_children'
|
||||
|
||||
eventbus.on('ParentEvent', parent_handler)
|
||||
|
||||
# Dispatch parent
|
||||
parent = ParentEvent(message='multi_child_parent')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# All children should have same parent
|
||||
assert len(child_events) == 3
|
||||
for child in child_events:
|
||||
assert child.parent_event_id == parent.event_id
|
||||
|
||||
async def test_parallel_handlers_parent_tracking(self, eventbus):
|
||||
"""Test parent tracking with parallel handlers"""
|
||||
events_from_handlers = {'h1': [], 'h2': []}
|
||||
|
||||
async def handler1(event: ParentEvent) -> str:
|
||||
await asyncio.sleep(0.01) # Simulate work
|
||||
child = ChildEvent(data='from_h1')
|
||||
eventbus.dispatch(child)
|
||||
events_from_handlers['h1'].append(child)
|
||||
return 'h1'
|
||||
|
||||
async def handler2(event: ParentEvent) -> str:
|
||||
await asyncio.sleep(0.02) # Different timing
|
||||
child = ChildEvent(data='from_h2')
|
||||
eventbus.dispatch(child)
|
||||
events_from_handlers['h2'].append(child)
|
||||
return 'h2'
|
||||
|
||||
# Both handlers respond to same event
|
||||
eventbus.on('ParentEvent', handler1)
|
||||
eventbus.on('ParentEvent', handler2)
|
||||
|
||||
# Dispatch parent
|
||||
parent = ParentEvent(message='parallel_test')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Both children should have same parent despite parallel execution
|
||||
assert len(events_from_handlers['h1']) == 1
|
||||
assert len(events_from_handlers['h2']) == 1
|
||||
assert events_from_handlers['h1'][0].parent_event_id == parent.event_id
|
||||
assert events_from_handlers['h2'][0].parent_event_id == parent.event_id
|
||||
|
||||
async def test_explicit_parent_not_overridden(self, eventbus):
|
||||
"""Test that explicitly set parent_event_id is not overridden"""
|
||||
captured_child = None
|
||||
|
||||
async def parent_handler(event: ParentEvent) -> str:
|
||||
nonlocal captured_child
|
||||
# Create child with explicit parent_event_id
|
||||
child = ChildEvent(data='explicit', parent_event_id='explicit_parent_id')
|
||||
eventbus.dispatch(child)
|
||||
captured_child = child
|
||||
return 'dispatched'
|
||||
|
||||
eventbus.on('ParentEvent', parent_handler)
|
||||
|
||||
parent = ParentEvent(message='test')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Explicit parent_event_id should be preserved
|
||||
assert captured_child is not None
|
||||
assert captured_child.parent_event_id == 'explicit_parent_id'
|
||||
assert captured_child.parent_event_id != parent.event_id
|
||||
|
||||
async def test_cross_eventbus_parent_tracking(self):
|
||||
"""Test parent tracking across multiple EventBuses"""
|
||||
bus1 = EventBus(name='Bus1')
|
||||
bus2 = EventBus(name='Bus2')
|
||||
|
||||
captured_events = []
|
||||
|
||||
async def bus1_handler(event: ParentEvent) -> str:
|
||||
# Dispatch child to bus2
|
||||
child = ChildEvent(data='cross_bus_child')
|
||||
bus2.dispatch(child)
|
||||
captured_events.append(('bus1', event, child))
|
||||
return 'bus1_handled'
|
||||
|
||||
async def bus2_handler(event: ChildEvent) -> str:
|
||||
captured_events.append(('bus2', event))
|
||||
return 'bus2_handled'
|
||||
|
||||
bus1.on('ParentEvent', bus1_handler)
|
||||
bus2.on('ChildEvent', bus2_handler)
|
||||
|
||||
try:
|
||||
# Dispatch parent to bus1
|
||||
parent = ParentEvent(message='cross_bus_test')
|
||||
bus1.dispatch(parent)
|
||||
|
||||
await bus1.wait_until_idle()
|
||||
await bus2.wait_until_idle()
|
||||
|
||||
# Verify parent tracking works across buses
|
||||
assert len(captured_events) == 2
|
||||
_, parent_event, child_event = captured_events[0]
|
||||
_, received_child = captured_events[1]
|
||||
|
||||
assert child_event.parent_event_id == parent.event_id
|
||||
assert received_child.parent_event_id == parent.event_id
|
||||
|
||||
finally:
|
||||
await bus1.stop()
|
||||
await bus2.stop()
|
||||
|
||||
async def test_sync_handler_parent_tracking(self, eventbus):
|
||||
"""Test parent tracking works with sync handlers"""
|
||||
child_events = []
|
||||
|
||||
def sync_parent_handler(event: ParentEvent) -> str:
|
||||
# Sync handler that dispatches child
|
||||
child = ChildEvent(data='from_sync')
|
||||
eventbus.dispatch(child)
|
||||
child_events.append(child)
|
||||
return 'sync_handled'
|
||||
|
||||
eventbus.on('ParentEvent', sync_parent_handler)
|
||||
|
||||
parent = ParentEvent(message='sync_test')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Parent tracking should work even with sync handlers
|
||||
assert len(child_events) == 1
|
||||
assert child_events[0].parent_event_id == parent.event_id
|
||||
|
||||
async def test_error_handler_parent_tracking(self, eventbus):
|
||||
"""Test parent tracking when handler errors occur"""
|
||||
child_events = []
|
||||
|
||||
async def failing_handler(event: ParentEvent) -> str:
|
||||
# Dispatch child before failing
|
||||
child = ChildEvent(data='before_error')
|
||||
eventbus.dispatch(child)
|
||||
child_events.append(child)
|
||||
raise ValueError("Handler failed")
|
||||
|
||||
async def success_handler(event: ParentEvent) -> str:
|
||||
# This should still run
|
||||
child = ChildEvent(data='after_error')
|
||||
eventbus.dispatch(child)
|
||||
child_events.append(child)
|
||||
return 'success'
|
||||
|
||||
eventbus.on('ParentEvent', failing_handler)
|
||||
eventbus.on('ParentEvent', success_handler)
|
||||
|
||||
parent = ParentEvent(message='error_test')
|
||||
eventbus.dispatch(parent)
|
||||
|
||||
await eventbus.wait_until_idle()
|
||||
|
||||
# Both children should have parent_event_id despite error
|
||||
assert len(child_events) == 2
|
||||
for child in child_events:
|
||||
assert child.parent_event_id == parent.event_id
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__, '-v'])
|
||||
Reference in New Issue
Block a user