Files
mistral-vibe/vibe/core/agent_loop.py
Mathias Gesbert ec7f3b25ea v2.2.0 (#395)
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai>
Co-authored-by: Clément Siriex <clement.sirieix@mistral.ai>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Michel Thomazo <michel.thomazo@mistral.ai>
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
2026-02-17 16:23:28 +01:00

1076 lines
38 KiB
Python

from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, Callable
from enum import StrEnum, auto
from http import HTTPStatus
import json
from pathlib import Path
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4
from pydantic import BaseModel
from vibe.core.agents.manager import AgentManager
from vibe.core.agents.models import AgentProfile, BuiltinAgentName
from vibe.core.config import Backend, ProviderConfig, VibeConfig
from vibe.core.llm.backend.factory import BACKEND_FACTORY
from vibe.core.llm.exceptions import BackendError
from vibe.core.llm.format import (
APIToolFormatHandler,
FailedToolCall,
ResolvedMessage,
ResolvedToolCall,
)
from vibe.core.llm.types import BackendLike
from vibe.core.middleware import (
AutoCompactMiddleware,
ContextWarningMiddleware,
ConversationContext,
MiddlewareAction,
MiddlewarePipeline,
MiddlewareResult,
PlanAgentMiddleware,
PriceLimitMiddleware,
ResetReason,
TurnLimitMiddleware,
)
from vibe.core.prompts import UtilityPrompt
from vibe.core.session.session_logger import SessionLogger
from vibe.core.session.session_migration import migrate_sessions_entrypoint
from vibe.core.skills.manager import SkillManager
from vibe.core.system_prompt import get_universal_system_prompt
from vibe.core.telemetry.send import TelemetryClient
from vibe.core.tools.base import (
BaseTool,
BaseToolConfig,
InvokeContext,
ToolError,
ToolPermission,
ToolPermissionError,
)
from vibe.core.tools.manager import ToolManager
from vibe.core.trusted_folders import has_agents_md_file
from vibe.core.types import (
AgentStats,
ApprovalCallback,
ApprovalResponse,
AssistantEvent,
AsyncApprovalCallback,
BaseEvent,
CompactEndEvent,
CompactStartEvent,
LLMChunk,
LLMMessage,
LLMUsage,
RateLimitError,
ReasoningEvent,
Role,
SyncApprovalCallback,
ToolCallEvent,
ToolResultEvent,
ToolStreamEvent,
UserInputCallback,
UserMessageEvent,
)
from vibe.core.utils import (
TOOL_ERROR_TAG,
VIBE_STOP_EVENT_TAG,
CancellationReason,
get_user_agent,
get_user_cancellation_message,
is_user_cancellation_event,
)
# Teleport (remote hand-off to Vibe Nuage) is an optional feature whose import
# can fail on minimal installs; track availability in a flag instead of
# crashing at module load. (The teleport_service property reports git as the
# missing dependency when this import failed.)
try:
    from vibe.core.teleport.teleport import TeleportService as _TeleportService

    _TELEPORT_AVAILABLE = True
except ImportError:
    _TELEPORT_AVAILABLE = False
    _TeleportService = None

# Annotation-only imports, kept out of the runtime import graph.
if TYPE_CHECKING:
    from vibe.core.teleport.nuage import TeleportSession
    from vibe.core.teleport.teleport import TeleportService
    from vibe.core.teleport.types import TeleportPushResponseEvent, TeleportYieldEvent
class ToolExecutionResponse(StrEnum):
SKIP = auto()
EXECUTE = auto()
class ToolDecision(BaseModel):
    """Outcome of the approval flow for a single tool call."""

    # Whether the tool call should be executed or skipped.
    verdict: ToolExecutionResponse
    # Permission level that produced this verdict (ALWAYS / NEVER / ASK).
    approval_type: ToolPermission
    # Optional human-readable reason (user feedback, denylist message, ...).
    feedback: str | None = None
class AgentLoopError(Exception):
    """Base exception for AgentLoop errors."""


class AgentLoopStateError(AgentLoopError):
    """Raised when the agent loop is in an invalid state."""


class AgentLoopLLMResponseError(AgentLoopError):
    """Raised when an LLM response is malformed or missing expected data."""


class TeleportError(AgentLoopError):
    """Raised when teleport to Vibe Nuage fails or is unavailable."""
def _should_raise_rate_limit_error(e: Exception) -> bool:
    """Return True when *e* is a backend error carrying HTTP 429 (rate limit)."""
    if not isinstance(e, BackendError):
        return False
    return e.status == HTTPStatus.TOO_MANY_REQUESTS
class AgentLoop:
    """Core conversation engine: runs LLM turns, executes tools, applies middleware."""

    def __init__(
        self,
        config: VibeConfig,
        agent_name: str = BuiltinAgentName.DEFAULT,
        message_observer: Callable[[LLMMessage], None] | None = None,
        max_turns: int | None = None,
        max_price: float | None = None,
        backend: BackendLike | None = None,
        enable_streaming: bool = False,
    ) -> None:
        self._base_config = config
        self._max_turns = max_turns
        self._max_price = max_price
        # Managers receive config *getters* so they always observe the latest
        # config after agent switches or reloads.
        self.agent_manager = AgentManager(
            lambda: self._base_config, initial_agent=agent_name
        )
        self.tool_manager = ToolManager(lambda: self.config)
        self.skill_manager = SkillManager(lambda: self.config)
        self.format_handler = APIToolFormatHandler()
        # Factory lets the backend be rebuilt on config reload; an explicitly
        # supplied backend always wins over auto-selection.
        self.backend_factory = lambda: backend or self._select_backend()
        self.backend = self.backend_factory()
        self.message_observer = message_observer
        self._last_observed_message_index: int = 0
        self.enable_streaming = enable_streaming
        self.middleware_pipeline = MiddlewarePipeline()
        self._setup_middleware()
        system_prompt = get_universal_system_prompt(
            self.tool_manager, self.config, self.skill_manager, self.agent_manager
        )
        self.messages = [LLMMessage(role=Role.system, content=system_prompt)]
        if self.message_observer:
            # The observer has now seen the system prompt (index 0).
            self.message_observer(self.messages[0])
            self._last_observed_message_index = 1
        self.stats = AgentStats()
        try:
            active_model = config.get_active_model()
            self.stats.input_price_per_million = active_model.input_price
            self.stats.output_price_per_million = active_model.output_price
        except ValueError:
            # No active model configured yet; pricing stays at its defaults.
            pass
        self.approval_callback: ApprovalCallback | None = None
        self.user_input_callback: UserInputCallback | None = None
        self.session_id = str(uuid4())
        self._current_user_message_id: str | None = None
        self.telemetry_client = TelemetryClient(config_getter=lambda: self.config)
        self.session_logger = SessionLogger(config.session_logging, self.session_id)
        self._teleport_service: TeleportService | None = None
        # Migrate legacy session files in the background so construction stays
        # fast; daemon=True so the thread never blocks interpreter shutdown.
        thread = Thread(
            target=migrate_sessions_entrypoint,
            args=(config.session_logging,),
            daemon=True,
            name="migrate_sessions",
        )
        thread.start()
@property
def agent_profile(self) -> AgentProfile:
    """Profile of the currently active agent."""
    return self.agent_manager.active_profile

@property
def config(self) -> VibeConfig:
    """Effective configuration for the active agent (derived from the base config)."""
    return self.agent_manager.config

@property
def auto_approve(self) -> bool:
    """When True, tool calls execute without consulting the approval callback."""
    return self.config.auto_approve
def set_tool_permission(
    self, tool_name: str, permission: ToolPermission, save_permanently: bool = False
) -> None:
    """Change a tool's permission for this session, optionally persisting it.

    Args:
        tool_name: Name of the tool to update.
        permission: New permission level for the tool.
        save_permanently: When True, also write the change to the saved config.
    """
    if save_permanently:
        VibeConfig.save_updates({
            "tools": {tool_name: {"permission": permission.value}}
        })
    if tool_name not in self.config.tools:
        self.config.tools[tool_name] = BaseToolConfig()
    self.config.tools[tool_name].permission = permission
    # Drop any cached tool instance so the next lookup sees the new permission.
    self.tool_manager.invalidate_tool(tool_name)
def emit_new_session_telemetry(
    self, entrypoint: Literal["cli", "acp", "programmatic"]
) -> None:
    """Send a new-session telemetry event describing the workspace setup.

    Args:
        entrypoint: How this session was started (CLI, ACP, or programmatic API).
    """
    has_agents_md = has_agents_md_file(Path.cwd())
    nb_skills = len(self.skill_manager.available_skills)
    nb_mcp_servers = len(self.config.mcp_servers)
    nb_models = len(self.config.models)
    self.telemetry_client.send_new_session(
        has_agents_md=has_agents_md,
        nb_skills=nb_skills,
        nb_mcp_servers=nb_mcp_servers,
        nb_models=nb_models,
        entrypoint=entrypoint,
    )
def _select_backend(self) -> BackendLike:
    """Instantiate a backend client for the active model's provider."""
    model = self.config.get_active_model()
    provider = self.config.get_provider_for_model(model)
    factory = BACKEND_FACTORY[provider.backend]
    return factory(provider=provider, timeout=self.config.api_timeout)
def add_message(self, message: LLMMessage) -> None:
    """Append a message to the conversation history."""
    self.messages.append(message)
async def _save_messages(self) -> None:
    """Persist the current transcript and stats to the session log."""
    await self.session_logger.save_interaction(
        self.messages,
        self.stats,
        self._base_config,
        self.tool_manager,
        self.agent_profile,
    )
async def _flush_new_messages(self) -> None:
await self._save_messages()
if not self.message_observer:
return
if self._last_observed_message_index >= len(self.messages):
return
for msg in self.messages[self._last_observed_message_index :]:
self.message_observer(msg)
self._last_observed_message_index = len(self.messages)
async def act(self, msg: str) -> AsyncGenerator[BaseEvent]:
    """Process one user message, yielding events until the turn completes."""
    # Repair transcript holes (e.g. unanswered tool calls from an interrupted
    # run) before sending the history back to the model.
    self._clean_message_history()
    async for event in self._conversation_loop(msg):
        yield event
@property
def teleport_service(self) -> TeleportService:
    """Lazily-built teleport client; raises TeleportError when unavailable."""
    if not _TELEPORT_AVAILABLE:
        # The optional teleport import failed at module load time.
        raise TeleportError(
            "Teleport requires git to be installed. "
            "Please install git and try again."
        )
    if self._teleport_service is None:
        if _TeleportService is None:
            # Defensive: _TELEPORT_AVAILABLE and _TeleportService are set
            # together, so this should be unreachable.
            raise TeleportError("_TeleportService is unexpectedly None")
        self._teleport_service = _TeleportService(
            session_logger=self.session_logger,
            nuage_base_url=self.config.nuage_base_url,
            nuage_workflow_id=self.config.nuage_workflow_id,
            nuage_api_key=self.config.nuage_api_key,
        )
    return self._teleport_service
def teleport_to_vibe_nuage(
    self, prompt: str | None
) -> AsyncGenerator[TeleportYieldEvent, TeleportPushResponseEvent | None]:
    """Package the current conversation and hand it off to Vibe Nuage.

    Returns a bidirectional generator: it yields teleport events and may be
    sent push responses via ``asend``.
    """
    from vibe.core.teleport.nuage import TeleportSession

    session = TeleportSession(
        metadata={
            "agent": self.agent_profile.name,
            "model": self.config.active_model,
            "stats": self.stats.model_dump(),
        },
        # messages[0] (the local system prompt) is intentionally excluded
        # from the teleported transcript.
        messages=[msg.model_dump(exclude_none=True) for msg in self.messages[1:]],
    )
    return self._teleport_generator(prompt, session)
async def _teleport_generator(
    self, prompt: str | None, session: TeleportSession
) -> AsyncGenerator[TeleportYieldEvent, TeleportPushResponseEvent | None]:
    """Drive the teleport service, relaying events out and responses back in."""
    from vibe.core.teleport.errors import ServiceTeleportError

    try:
        async with self.teleport_service:
            gen = self.teleport_service.execute(prompt=prompt, session=session)
            response: TeleportPushResponseEvent | None = None
            # Manual asend() pump: forward each service event to our caller
            # and feed the caller's reply back into the service generator.
            while True:
                try:
                    event = await gen.asend(response)
                    response = yield event
                except StopAsyncIteration:
                    break
    except ServiceTeleportError as e:
        # Normalize service-level failures into this module's error type.
        raise TeleportError(str(e)) from e
    finally:
        # Always drop the cached client so the next teleport starts fresh.
        self._teleport_service = None
def _setup_middleware(self) -> None:
    """Configure middleware pipeline for this conversation."""
    self.middleware_pipeline.clear()
    # Hard limits first: these can stop the loop outright.
    if self._max_turns is not None:
        self.middleware_pipeline.add(TurnLimitMiddleware(self._max_turns))
    if self._max_price is not None:
        self.middleware_pipeline.add(PriceLimitMiddleware(self._max_price))
    # Context management: auto-compaction and warnings use the same threshold.
    if self.config.auto_compact_threshold > 0:
        self.middleware_pipeline.add(
            AutoCompactMiddleware(self.config.auto_compact_threshold)
        )
    if self.config.context_warnings:
        self.middleware_pipeline.add(
            ContextWarningMiddleware(0.5, self.config.auto_compact_threshold)
        )
    # Always installed; reads the active profile lazily via the getter.
    self.middleware_pipeline.add(PlanAgentMiddleware(lambda: self.agent_profile))
async def _handle_middleware_result(
    self, result: MiddlewareResult
) -> AsyncGenerator[BaseEvent]:
    """Translate a middleware decision into events and/or history mutations."""
    match result.action:
        case MiddlewareAction.STOP:
            # Tag the stop reason so consumers can recognize
            # middleware-initiated stops.
            yield AssistantEvent(
                content=f"<{VIBE_STOP_EVENT_TAG}>{result.reason}</{VIBE_STOP_EVENT_TAG}>",
                stopped_by_middleware=True,
            )
        case MiddlewareAction.INJECT_MESSAGE:
            if result.message:
                injected_message = LLMMessage(
                    role=Role.user, content=result.message
                )
                self.messages.append(injected_message)
        case MiddlewareAction.COMPACT:
            # Fall back to live stats/config when the middleware did not
            # record the values that triggered compaction.
            old_tokens = result.metadata.get(
                "old_tokens", self.stats.context_tokens
            )
            threshold = result.metadata.get(
                "threshold", self.config.auto_compact_threshold
            )
            # Synthetic id pairing the start/end events for rendering.
            tool_call_id = str(uuid4())
            yield CompactStartEvent(
                tool_call_id=tool_call_id,
                current_context_tokens=old_tokens,
                threshold=threshold,
            )
            self.telemetry_client.send_auto_compact_triggered()
            summary = await self.compact()
            yield CompactEndEvent(
                tool_call_id=tool_call_id,
                old_context_tokens=old_tokens,
                new_context_tokens=self.stats.context_tokens,
                summary_length=len(summary),
            )
        case MiddlewareAction.CONTINUE:
            pass
def _get_context(self) -> ConversationContext:
    """Snapshot the state the middleware pipeline needs for its decisions."""
    return ConversationContext(
        messages=self.messages, stats=self.stats, config=self.config
    )
def _get_extra_headers(self, provider: ProviderConfig) -> dict[str, str]:
    """Build per-request HTTP headers: user agent, session affinity, metadata."""
    headers: dict[str, str] = {
        "user-agent": get_user_agent(provider.backend),
        # Route requests of one session to the same backend shard.
        "x-affinity": self.session_id,
    }
    message_id = self._current_user_message_id
    # Only the Mistral backend receives the message-correlation metadata header.
    if provider.backend == Backend.MISTRAL and message_id is not None:
        headers["metadata"] = json.dumps({"message_id": message_id})
    return headers
async def _conversation_loop(self, user_msg: str) -> AsyncGenerator[BaseEvent]:
    """Main turn loop: alternate LLM turns and tool execution until done.

    Terminates when the model replies without tool calls, middleware stops
    the run, or the user cancels mid-turn.
    """
    user_message = LLMMessage(role=Role.user, content=user_msg)
    self.messages.append(user_message)
    self.stats.steps += 1
    self._current_user_message_id = user_message.message_id
    if user_message.message_id is None:
        raise AgentLoopError("User message must have a message_id")
    yield UserMessageEvent(content=user_msg, message_id=user_message.message_id)
    try:
        should_break_loop = False
        while not should_break_loop:
            # Middleware may stop the loop, inject a message, or compact.
            result = await self.middleware_pipeline.run_before_turn(
                self._get_context()
            )
            async for event in self._handle_middleware_result(result):
                yield event
            if result.action == MiddlewareAction.STOP:
                return
            self.stats.steps += 1
            user_cancelled = False
            async for event in self._perform_llm_turn():
                if is_user_cancellation_event(event):
                    user_cancelled = True
                yield event
            await self._flush_new_messages()
            # A trailing tool message means the model needs another turn to
            # read the tool results.
            last_message = self.messages[-1]
            should_break_loop = last_message.role != Role.tool
            if user_cancelled:
                return
    finally:
        # Always persist and notify, even on cancellation or errors.
        await self._flush_new_messages()
async def _perform_llm_turn(self) -> AsyncGenerator[BaseEvent, None]:
    """Run one model completion, then execute any tool calls it requested."""
    if self.enable_streaming:
        async for event in self._stream_assistant_events():
            yield event
    else:
        assistant_event = await self._get_assistant_event()
        # Suppress empty assistant events (e.g. tool-call-only replies).
        if assistant_event.content:
            yield assistant_event
    # Both branches above append the assistant reply to self.messages.
    last_message = self.messages[-1]
    parsed = self.format_handler.parse_message(last_message)
    resolved = self.format_handler.resolve_tool_calls(parsed, self.tool_manager)
    if not resolved.tool_calls and not resolved.failed_calls:
        return
    async for event in self._handle_tool_calls(resolved):
        yield event
async def _stream_assistant_events(
    self,
) -> AsyncGenerator[AssistantEvent | ReasoningEvent]:
    """Stream a completion, batching chunks into events of BATCH_SIZE chunks.

    Reasoning and content are buffered separately; switching between the two
    flushes the other buffer first so events are emitted in arrival order.
    """
    content_buffer = ""
    reasoning_buffer = ""
    chunks_with_content = 0
    chunks_with_reasoning = 0
    message_id: str | None = None
    # Number of chunks accumulated before an event is emitted.
    BATCH_SIZE = 5
    async for chunk in self._chat_streaming():
        if message_id is None:
            # All events of this turn carry the first chunk's message id.
            message_id = chunk.message.message_id
        if chunk.message.reasoning_content:
            # Entering reasoning: flush pending content so ordering is kept.
            if content_buffer:
                yield AssistantEvent(content=content_buffer, message_id=message_id)
                content_buffer = ""
                chunks_with_content = 0
            reasoning_buffer += chunk.message.reasoning_content
            chunks_with_reasoning += 1
            if chunks_with_reasoning >= BATCH_SIZE:
                yield ReasoningEvent(
                    content=reasoning_buffer, message_id=message_id
                )
                reasoning_buffer = ""
                chunks_with_reasoning = 0
        if chunk.message.content:
            # Entering content: flush pending reasoning first.
            if reasoning_buffer:
                yield ReasoningEvent(
                    content=reasoning_buffer, message_id=message_id
                )
                reasoning_buffer = ""
                chunks_with_reasoning = 0
            content_buffer += chunk.message.content
            chunks_with_content += 1
            if chunks_with_content >= BATCH_SIZE:
                yield AssistantEvent(content=content_buffer, message_id=message_id)
                content_buffer = ""
                chunks_with_content = 0
    # Flush whatever remains after the stream ends.
    if reasoning_buffer:
        yield ReasoningEvent(content=reasoning_buffer, message_id=message_id)
    if content_buffer:
        yield AssistantEvent(content=content_buffer, message_id=message_id)
async def _get_assistant_event(self) -> AssistantEvent:
    """Run one non-streaming completion and wrap the reply as an event."""
    chunk = await self._chat()
    message = chunk.message
    return AssistantEvent(
        content=message.content or "",
        message_id=message.message_id,
    )
async def _emit_failed_tool_events(
    self, failed_calls: list[FailedToolCall]
) -> AsyncGenerator[ToolResultEvent]:
    """Emit error events and tool-role error responses for unresolvable calls."""
    for failed in failed_calls:
        error_msg = f"<{TOOL_ERROR_TAG}>{failed.tool_name}: {failed.error}</{TOOL_ERROR_TAG}>"
        yield ToolResultEvent(
            tool_name=failed.tool_name,
            tool_class=None,
            error=error_msg,
            tool_call_id=failed.call_id,
        )
        self.stats.tool_calls_failed += 1
        # The model still needs a tool-role response for every call it made.
        self.messages.append(
            self.format_handler.create_failed_tool_response_message(
                failed, error_msg
            )
        )
async def _process_one_tool_call(
    self, tool_call: ResolvedToolCall
) -> AsyncGenerator[ToolResultEvent | ToolStreamEvent]:
    """Approve, execute, and record a single resolved tool call.

    Yields streaming events produced by the tool plus a final
    ToolResultEvent; always appends a matching tool-role response to the
    history via _handle_tool_response.
    """
    try:
        tool_instance = self.tool_manager.get(tool_call.tool_name)
    except Exception as exc:
        error_msg = f"Error getting tool '{tool_call.tool_name}': {exc}"
        yield ToolResultEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            error=error_msg,
            tool_call_id=tool_call.call_id,
        )
        self._handle_tool_response(tool_call, error_msg, "failure")
        return
    decision = await self._should_execute_tool(
        tool_instance, tool_call.validated_args, tool_call.call_id
    )
    if decision.verdict == ToolExecutionResponse.SKIP:
        self.stats.tool_calls_rejected += 1
        skip_reason = decision.feedback or str(
            get_user_cancellation_message(
                CancellationReason.TOOL_SKIPPED, tool_call.tool_name
            )
        )
        yield ToolResultEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            skipped=True,
            skip_reason=skip_reason,
            tool_call_id=tool_call.call_id,
        )
        self._handle_tool_response(tool_call, skip_reason, "skipped", decision)
        return
    self.stats.tool_calls_agreed += 1
    try:
        start_time = time.perf_counter()
        result_model = None
        # Tools may yield intermediate ToolStreamEvents; the last
        # non-stream item yielded is taken as the tool's final result.
        async for item in tool_instance.invoke(
            ctx=InvokeContext(
                tool_call_id=tool_call.call_id,
                approval_callback=self.approval_callback,
                agent_manager=self.agent_manager,
                user_input_callback=self.user_input_callback,
            ),
            **tool_call.args_dict,
        ):
            if isinstance(item, ToolStreamEvent):
                yield item
            else:
                result_model = item
        duration = time.perf_counter() - start_time
        if result_model is None:
            raise ToolError("Tool did not yield a result")
        result_dict = result_model.model_dump()
        # Flatten the result model into "key: value" lines for the model.
        text = "\n".join(f"{k}: {v}" for k, v in result_dict.items())
        self._handle_tool_response(
            tool_call, text, "success", decision, result_dict
        )
        yield ToolResultEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            result=result_model,
            duration=duration,
            tool_call_id=tool_call.call_id,
        )
        self.stats.tool_calls_succeeded += 1
    except asyncio.CancelledError:
        # User interrupt: record the cancellation, then re-raise so the
        # surrounding task actually cancels.
        cancel = str(
            get_user_cancellation_message(CancellationReason.TOOL_INTERRUPTED)
        )
        yield ToolResultEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            error=cancel,
            tool_call_id=tool_call.call_id,
        )
        self._handle_tool_response(tool_call, cancel, "failure", decision)
        raise
    except (ToolError, ToolPermissionError) as exc:
        error_msg = f"<{TOOL_ERROR_TAG}>{tool_instance.get_name()} failed: {exc}</{TOOL_ERROR_TAG}>"
        yield ToolResultEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            error=error_msg,
            tool_call_id=tool_call.call_id,
        )
        if isinstance(exc, ToolPermissionError):
            # Permission failure counts as a rejection, not an agreed run:
            # undo the agreed increment from above.
            self.stats.tool_calls_agreed -= 1
            self.stats.tool_calls_rejected += 1
        else:
            self.stats.tool_calls_failed += 1
        self._handle_tool_response(tool_call, error_msg, "failure", decision)
async def _handle_tool_calls(
    self, resolved: ResolvedMessage
) -> AsyncGenerator[ToolCallEvent | ToolResultEvent | ToolStreamEvent]:
    """Process every failed and resolved tool call from one assistant message."""
    async for event in self._emit_failed_tool_events(resolved.failed_calls):
        yield event
    # Tool calls are executed sequentially, in the order the model emitted them.
    for tool_call in resolved.tool_calls:
        yield ToolCallEvent(
            tool_name=tool_call.tool_name,
            tool_class=tool_call.tool_class,
            args=tool_call.validated_args,
            tool_call_id=tool_call.call_id,
        )
        async for event in self._process_one_tool_call(tool_call):
            yield event
def _handle_tool_response(
    self,
    tool_call: ResolvedToolCall,
    text: str,
    status: Literal["success", "failure", "skipped"],
    decision: ToolDecision | None = None,
    result: dict[str, Any] | None = None,
) -> None:
    """Append a tool-role response to the history and report telemetry.

    Args:
        tool_call: The call being answered.
        text: Response text shown to the model.
        status: Outcome category of the call.
        decision: Approval decision, when one was made.
        result: Raw result payload, when the tool succeeded.
    """
    self.messages.append(
        LLMMessage.model_validate(
            self.format_handler.create_tool_response_message(tool_call, text)
        )
    )
    self.telemetry_client.send_tool_call_finished(
        tool_call=tool_call,
        agent_profile_name=self.agent_profile.name,
        status=status,
        decision=decision,
        result=result,
    )
async def _chat(self, max_tokens: int | None = None) -> LLMChunk:
    """Run one non-streaming completion and append the reply to history.

    Updates token/latency stats from the returned usage data.

    Raises:
        RateLimitError: when the backend reports HTTP 429.
        AgentLoopLLMResponseError: when the response carries no usage data.
        RuntimeError: for any other backend failure.
    """
    active_model = self.config.get_active_model()
    provider = self.config.get_provider_for_model(active_model)
    available_tools = self.format_handler.get_available_tools(self.tool_manager)
    tool_choice = self.format_handler.get_tool_choice()
    try:
        start_time = time.perf_counter()
        result = await self.backend.complete(
            model=active_model,
            messages=self.messages,
            temperature=active_model.temperature,
            tools=available_tools,
            tool_choice=tool_choice,
            extra_headers=self._get_extra_headers(provider),
            max_tokens=max_tokens,
        )
        end_time = time.perf_counter()
        if result.usage is None:
            raise AgentLoopLLMResponseError(
                "Usage data missing in non-streaming completion response"
            )
        self._update_stats(usage=result.usage, time_seconds=end_time - start_time)
        processed_message = self.format_handler.process_api_response_message(
            result.message
        )
        self.messages.append(processed_message)
        return LLMChunk(message=processed_message, usage=result.usage)
    except AgentLoopError:
        # Our own diagnostics (e.g. missing usage) must propagate as-is
        # instead of being re-labelled as a provider "API error" below.
        raise
    except Exception as e:
        if _should_raise_rate_limit_error(e):
            raise RateLimitError(provider.name, active_model.name) from e
        raise RuntimeError(
            f"API error from {provider.name} (model: {active_model.name}): {e}"
        ) from e
async def _chat_streaming(
    self, max_tokens: int | None = None
) -> AsyncGenerator[LLMChunk]:
    """Stream one completion, yielding processed chunks as they arrive.

    Aggregates all chunks into a single assistant message that is appended
    to the history once the stream completes; stats are updated from the
    summed per-chunk usage.

    Raises:
        RateLimitError: when the backend reports HTTP 429.
        AgentLoopLLMResponseError: when no chunk carried usage data.
        RuntimeError: for any other backend failure.
    """
    active_model = self.config.get_active_model()
    provider = self.config.get_provider_for_model(active_model)
    available_tools = self.format_handler.get_available_tools(self.tool_manager)
    tool_choice = self.format_handler.get_tool_choice()
    try:
        start_time = time.perf_counter()
        usage = LLMUsage()
        chunk_agg = LLMChunk(message=LLMMessage(role=Role.assistant))
        async for chunk in self.backend.complete_streaming(
            model=active_model,
            messages=self.messages,
            temperature=active_model.temperature,
            tools=available_tools,
            tool_choice=tool_choice,
            extra_headers=self._get_extra_headers(provider),
            max_tokens=max_tokens,
        ):
            processed_message = self.format_handler.process_api_response_message(
                chunk.message
            )
            processed_chunk = LLMChunk(message=processed_message, usage=chunk.usage)
            chunk_agg += processed_chunk
            usage += chunk.usage or LLMUsage()
            yield processed_chunk
        end_time = time.perf_counter()
        if chunk_agg.usage is None:
            raise AgentLoopLLMResponseError(
                "Usage data missing in final chunk of streamed completion"
            )
        self._update_stats(usage=usage, time_seconds=end_time - start_time)
        self.messages.append(chunk_agg.message)
    except AgentLoopError:
        # Our own diagnostics (e.g. missing usage) must propagate as-is
        # instead of being re-labelled as a provider "API error" below.
        raise
    except Exception as e:
        if _should_raise_rate_limit_error(e):
            raise RateLimitError(provider.name, active_model.name) from e
        raise RuntimeError(
            f"API error from {provider.name} (model: {active_model.name}): {e}"
        ) from e
def _update_stats(self, usage: LLMUsage, time_seconds: float) -> None:
self.stats.last_turn_duration = time_seconds
self.stats.last_turn_prompt_tokens = usage.prompt_tokens
self.stats.last_turn_completion_tokens = usage.completion_tokens
self.stats.session_prompt_tokens += usage.prompt_tokens
self.stats.session_completion_tokens += usage.completion_tokens
self.stats.context_tokens = usage.prompt_tokens + usage.completion_tokens
if time_seconds > 0 and usage.completion_tokens > 0:
self.stats.tokens_per_second = usage.completion_tokens / time_seconds
async def _should_execute_tool(
    self, tool: BaseTool, args: BaseModel, tool_call_id: str
) -> ToolDecision:
    """Decide whether to run a tool.

    Precedence: global auto-approve, then argument-level allow/deny lists,
    then the tool's configured permission, then interactive approval.
    """
    if self.auto_approve:
        return ToolDecision(
            verdict=ToolExecutionResponse.EXECUTE,
            approval_type=ToolPermission.ALWAYS,
        )
    # Argument-level allowlist/denylist beats the tool-level permission.
    allowlist_denylist_result = tool.check_allowlist_denylist(args)
    if allowlist_denylist_result == ToolPermission.ALWAYS:
        return ToolDecision(
            verdict=ToolExecutionResponse.EXECUTE,
            approval_type=ToolPermission.ALWAYS,
        )
    elif allowlist_denylist_result == ToolPermission.NEVER:
        denylist_patterns = tool.config.denylist
        denylist_str = ", ".join(repr(pattern) for pattern in denylist_patterns)
        return ToolDecision(
            verdict=ToolExecutionResponse.SKIP,
            approval_type=ToolPermission.NEVER,
            feedback=f"Tool '{tool.get_name()}' blocked by denylist: [{denylist_str}]",
        )
    tool_name = tool.get_name()
    perm = self.tool_manager.get_tool_config(tool_name).permission
    if perm is ToolPermission.ALWAYS:
        return ToolDecision(
            verdict=ToolExecutionResponse.EXECUTE,
            approval_type=ToolPermission.ALWAYS,
        )
    if perm is ToolPermission.NEVER:
        return ToolDecision(
            verdict=ToolExecutionResponse.SKIP,
            approval_type=ToolPermission.NEVER,
            feedback=f"Tool '{tool_name}' is permanently disabled",
        )
    # Permission is ASK (or unspecified): defer to the approval callback.
    return await self._ask_approval(tool_name, args, tool_call_id)
async def _ask_approval(
self, tool_name: str, args: BaseModel, tool_call_id: str
) -> ToolDecision:
if not self.approval_callback:
return ToolDecision(
verdict=ToolExecutionResponse.SKIP,
approval_type=ToolPermission.ASK,
feedback="Tool execution not permitted.",
)
if asyncio.iscoroutinefunction(self.approval_callback):
async_callback = cast(AsyncApprovalCallback, self.approval_callback)
response, feedback = await async_callback(tool_name, args, tool_call_id)
else:
sync_callback = cast(SyncApprovalCallback, self.approval_callback)
response, feedback = sync_callback(tool_name, args, tool_call_id)
match response:
case ApprovalResponse.YES:
return ToolDecision(
verdict=ToolExecutionResponse.EXECUTE,
approval_type=ToolPermission.ASK,
feedback=feedback,
)
case ApprovalResponse.NO:
return ToolDecision(
verdict=ToolExecutionResponse.SKIP,
approval_type=ToolPermission.ASK,
feedback=feedback,
)
def _clean_message_history(self) -> None:
ACCEPTABLE_HISTORY_SIZE = 2
if len(self.messages) < ACCEPTABLE_HISTORY_SIZE:
return
self._fill_missing_tool_responses()
self._ensure_assistant_after_tools()
def _fill_missing_tool_responses(self) -> None:
    """Insert placeholder tool responses for unanswered assistant tool calls.

    Providers reject a transcript in which an assistant message declares N
    tool calls but fewer than N tool-role responses follow (e.g. after a
    user interrupt). Walk the history and pad each such gap with a
    cancellation-note response per missing call.
    """
    # Start at 1: index 0 is the system prompt.
    i = 1
    while i < len(self.messages):  # noqa: PLR1702
        msg = self.messages[i]
        if msg.role == "assistant" and msg.tool_calls:
            expected_responses = len(msg.tool_calls)
            if expected_responses > 0:
                # Count the consecutive tool-role messages that follow.
                actual_responses = 0
                j = i + 1
                while j < len(self.messages) and self.messages[j].role == "tool":
                    actual_responses += 1
                    j += 1
                if actual_responses < expected_responses:
                    # Pad right after the last real response, in call order.
                    insertion_point = i + 1 + actual_responses
                    for call_idx in range(actual_responses, expected_responses):
                        tool_call_data = msg.tool_calls[call_idx]
                        empty_response = LLMMessage(
                            role=Role.tool,
                            tool_call_id=tool_call_data.id or "",
                            name=(tool_call_data.function.name or "")
                            if tool_call_data.function
                            else "",
                            content=str(
                                get_user_cancellation_message(
                                    CancellationReason.TOOL_NO_RESPONSE
                                )
                            ),
                        )
                        self.messages.insert(insertion_point, empty_response)
                        insertion_point += 1
                # Skip past this assistant message and all of its (now
                # complete) tool responses.
                i = i + 1 + expected_responses
                continue
        i += 1
def _ensure_assistant_after_tools(self) -> None:
    """Close the turn when the transcript ends on a tool-role message.

    Appends a minimal assistant acknowledgement so the history does not end
    with an unanswered tool response.
    """
    MIN_MESSAGE_SIZE = 2
    if len(self.messages) < MIN_MESSAGE_SIZE:
        return
    if self.messages[-1].role is Role.tool:
        self.messages.append(
            LLMMessage(role=Role.assistant, content="Understood.")
        )
def _reset_session(self) -> None:
    """Start a fresh session id and point the session logger at it."""
    self.session_id = str(uuid4())
    self.session_logger.reset_session(self.session_id)

def set_approval_callback(self, callback: ApprovalCallback) -> None:
    """Register the (sync or async) callback used to approve tool calls."""
    self.approval_callback = callback

def set_user_input_callback(self, callback: UserInputCallback) -> None:
    """Register the callback tools use to request input from the user."""
    self.user_input_callback = callback
async def clear_history(self) -> None:
    """Drop the conversation (keeping the system prompt) and reset state.

    Persists the outgoing transcript first, then resets stats, middleware,
    tools, and starts a new session id.
    """
    await self.session_logger.save_interaction(
        self.messages,
        self.stats,
        self._base_config,
        self.tool_manager,
        self.agent_profile,
    )
    # Keep only the system prompt (index 0).
    self.messages = self.messages[:1]
    self.stats = AgentStats()
    self.stats.trigger_listeners()
    try:
        active_model = self.config.get_active_model()
        self.stats.update_pricing(
            active_model.input_price, active_model.output_price
        )
    except ValueError:
        # No active model configured; pricing stays at defaults.
        pass
    self.middleware_pipeline.reset()
    self.tool_manager.reset_all()
    self._reset_session()
async def compact(self) -> str:
    """Replace the conversation with an LLM-generated summary.

    Asks the model to summarize the history, then rebuilds the transcript
    as [system prompt, summary] and recounts context tokens. Starts a new
    logging session so the pre-compaction transcript stays intact on disk.

    Returns:
        The summary text (possibly empty).

    Raises:
        AgentLoopLLMResponseError: if the summary response has no usage data.
    """
    try:
        self._clean_message_history()
        # Persist the full transcript before it is discarded below.
        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )
        summary_request = UtilityPrompt.COMPACT.read()
        self.messages.append(LLMMessage(role=Role.user, content=summary_request))
        self.stats.steps += 1
        summary_result = await self._chat()
        if summary_result.usage is None:
            raise AgentLoopLLMResponseError(
                "Usage data missing in compaction summary response"
            )
        summary_content = summary_result.message.content or ""
        system_message = self.messages[0]
        # The summary re-enters the conversation as a user message.
        summary_message = LLMMessage(role=Role.user, content=summary_content)
        self.messages = [system_message, summary_message]
        # Observer has effectively seen the system prompt; the summary is new.
        self._last_observed_message_index = 1
        active_model = self.config.get_active_model()
        provider = self.config.get_provider_for_model(active_model)
        # Recount tokens for the shrunken transcript instead of trusting
        # stale usage numbers.
        actual_context_tokens = await self.backend.count_tokens(
            model=active_model,
            messages=self.messages,
            tools=self.format_handler.get_available_tools(self.tool_manager),
            extra_headers={"user-agent": get_user_agent(provider.backend)},
        )
        self.stats.context_tokens = actual_context_tokens
        self._reset_session()
        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )
        self.middleware_pipeline.reset(reset_reason=ResetReason.COMPACT)
        return summary_content or ""
    except Exception:
        # Best-effort save of whatever state we reached, then propagate.
        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )
        raise
async def switch_agent(self, agent_name: str) -> None:
    """Activate *agent_name* and rebuild the loop around its profile.

    No-op when the requested agent is already active. Middleware is kept
    as-is across the switch.
    """
    if agent_name != self.agent_profile.name:
        self.agent_manager.switch_profile(agent_name)
        await self.reload_with_initial_messages(reset_middleware=False)
async def reload_with_initial_messages(
    self,
    base_config: VibeConfig | None = None,
    max_turns: int | None = None,
    max_price: float | None = None,
    reset_middleware: bool = True,
) -> None:
    """Rebuild managers, backend, and system prompt, keeping the transcript.

    Args:
        base_config: Replacement base config; None keeps the current one.
        max_turns: New turn limit; None keeps the current one.
        max_price: New price limit; None keeps the current one.
        reset_middleware: When True, rebuild the middleware pipeline too.
    """
    # Force an immediate yield to allow the UI to update before heavy sync work.
    # When there are no messages, save_interaction returns early without any await,
    # so the coroutine would run synchronously through ToolManager, SkillManager,
    # and system prompt generation without yielding control to the event loop.
    await asyncio.sleep(0)
    await self.session_logger.save_interaction(
        self.messages,
        self.stats,
        self._base_config,
        self.tool_manager,
        self.agent_profile,
    )
    if base_config is not None:
        self._base_config = base_config
        self.agent_manager.invalidate_config()
    # Rebuild the backend: the active model/provider may have changed.
    self.backend = self.backend_factory()
    if max_turns is not None:
        self._max_turns = max_turns
    if max_price is not None:
        self._max_price = max_price
    self.tool_manager = ToolManager(lambda: self.config)
    self.skill_manager = SkillManager(lambda: self.config)
    new_system_prompt = get_universal_system_prompt(
        self.tool_manager, self.config, self.skill_manager, self.agent_manager
    )
    # Swap in the fresh system prompt, preserving all non-system messages.
    self.messages = [
        LLMMessage(role=Role.system, content=new_system_prompt),
        *[msg for msg in self.messages if msg.role != Role.system],
    ]
    if len(self.messages) == 1:
        # Transcript is effectively empty: context counters start over.
        self.stats.reset_context_state()
    try:
        active_model = self.config.get_active_model()
        self.stats.update_pricing(
            active_model.input_price, active_model.output_price
        )
    except ValueError:
        # No active model configured; pricing stays as-is.
        pass
    if reset_middleware:
        self._setup_middleware()