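"""Agent loop orchestration: conversation turns, LLM calls, tool execution,
middleware, history compaction, and (optionally) teleport to Vibe Nuage."""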
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator, Callable
from enum import StrEnum, auto
from http import HTTPStatus
import json
from pathlib import Path
from threading import Thread
import time
from typing import TYPE_CHECKING, Any, Literal, cast
from uuid import uuid4

from pydantic import BaseModel

from vibe.core.agents.manager import AgentManager
from vibe.core.agents.models import AgentProfile, BuiltinAgentName
from vibe.core.config import Backend, ProviderConfig, VibeConfig
from vibe.core.llm.backend.factory import BACKEND_FACTORY
from vibe.core.llm.exceptions import BackendError
from vibe.core.llm.format import (
    APIToolFormatHandler,
    FailedToolCall,
    ResolvedMessage,
    ResolvedToolCall,
)
from vibe.core.llm.types import BackendLike
from vibe.core.middleware import (
    AutoCompactMiddleware,
    ContextWarningMiddleware,
    ConversationContext,
    MiddlewareAction,
    MiddlewarePipeline,
    MiddlewareResult,
    PlanAgentMiddleware,
    PriceLimitMiddleware,
    ResetReason,
    TurnLimitMiddleware,
)
from vibe.core.prompts import UtilityPrompt
from vibe.core.session.session_logger import SessionLogger
from vibe.core.session.session_migration import migrate_sessions_entrypoint
from vibe.core.skills.manager import SkillManager
from vibe.core.system_prompt import get_universal_system_prompt
from vibe.core.telemetry.send import TelemetryClient
from vibe.core.tools.base import (
    BaseTool,
    BaseToolConfig,
    InvokeContext,
    ToolError,
    ToolPermission,
    ToolPermissionError,
)
from vibe.core.tools.manager import ToolManager
from vibe.core.trusted_folders import has_agents_md_file
from vibe.core.types import (
    AgentStats,
    ApprovalCallback,
    ApprovalResponse,
    AssistantEvent,
    AsyncApprovalCallback,
    BaseEvent,
    CompactEndEvent,
    CompactStartEvent,
    LLMChunk,
    LLMMessage,
    LLMUsage,
    RateLimitError,
    ReasoningEvent,
    Role,
    SyncApprovalCallback,
    ToolCallEvent,
    ToolResultEvent,
    ToolStreamEvent,
    UserInputCallback,
    UserMessageEvent,
)
from vibe.core.utils import (
    TOOL_ERROR_TAG,
    VIBE_STOP_EVENT_TAG,
    CancellationReason,
    get_user_agent,
    get_user_cancellation_message,
    is_user_cancellation_event,
)

try:
    from vibe.core.teleport.teleport import TeleportService as _TeleportService

    _TELEPORT_AVAILABLE = True
except ImportError:
    _TELEPORT_AVAILABLE = False
    _TeleportService = None

if TYPE_CHECKING:
    from vibe.core.teleport.nuage import TeleportSession
    from vibe.core.teleport.teleport import TeleportService
    from vibe.core.teleport.types import TeleportPushResponseEvent, TeleportYieldEvent


class ToolExecutionResponse(StrEnum):
    SKIP = auto()
    EXECUTE = auto()


class ToolDecision(BaseModel):
    verdict: ToolExecutionResponse
    approval_type: ToolPermission
    feedback: str | None = None


class AgentLoopError(Exception):
    """Base exception for AgentLoop errors."""


class AgentLoopStateError(AgentLoopError):
    """Raised when the agent loop is in an invalid state."""


class AgentLoopLLMResponseError(AgentLoopError):
    """Raised when an LLM response is malformed or missing expected data."""


class TeleportError(AgentLoopError):
    """Raised when a teleport to Vibe Nuage fails."""


def _should_raise_rate_limit_error(e: Exception) -> bool:
    return isinstance(e, BackendError) and e.status == HTTPStatus.TOO_MANY_REQUESTS


class AgentLoop:
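    """Drives a multi-turn conversation between the user, an LLM backend, and tools.

    A minimal usage sketch (the surrounding wiring is illustrative, not prescriptive):

        loop = AgentLoop(config)
        loop.set_approval_callback(my_approval_callback)  # hypothetical callback
        async for event in loop.act("Summarize this repository"):
            ...  # render the event
    """
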
    def __init__(
        self,
        config: VibeConfig,
        agent_name: str = BuiltinAgentName.DEFAULT,
        message_observer: Callable[[LLMMessage], None] | None = None,
        max_turns: int | None = None,
        max_price: float | None = None,
        backend: BackendLike | None = None,
        enable_streaming: bool = False,
    ) -> None:
        self._base_config = config
        self._max_turns = max_turns
        self._max_price = max_price

        self.agent_manager = AgentManager(
            lambda: self._base_config, initial_agent=agent_name
        )
        self.tool_manager = ToolManager(lambda: self.config)
        self.skill_manager = SkillManager(lambda: self.config)
        self.format_handler = APIToolFormatHandler()

        self.backend_factory = lambda: backend or self._select_backend()
        self.backend = self.backend_factory()

        self.message_observer = message_observer
        self._last_observed_message_index: int = 0
        self.enable_streaming = enable_streaming
        self.middleware_pipeline = MiddlewarePipeline()
        self._setup_middleware()

        system_prompt = get_universal_system_prompt(
            self.tool_manager, self.config, self.skill_manager, self.agent_manager
        )
        self.messages = [LLMMessage(role=Role.system, content=system_prompt)]

        if self.message_observer:
            self.message_observer(self.messages[0])
            self._last_observed_message_index = 1

        self.stats = AgentStats()
        try:
            active_model = config.get_active_model()
            self.stats.input_price_per_million = active_model.input_price
            self.stats.output_price_per_million = active_model.output_price
        except ValueError:
            pass

        self.approval_callback: ApprovalCallback | None = None
        self.user_input_callback: UserInputCallback | None = None

        self.session_id = str(uuid4())
        self._current_user_message_id: str | None = None

        self.telemetry_client = TelemetryClient(config_getter=lambda: self.config)
        self.session_logger = SessionLogger(config.session_logging, self.session_id)
        self._teleport_service: TeleportService | None = None

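        # Run session migration in a background daemon thread, presumably so
        # construction is not blocked and the thread cannot keep the process
        # alive on exit.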
        thread = Thread(
            target=migrate_sessions_entrypoint,
            args=(config.session_logging,),
            daemon=True,
            name="migrate_sessions",
        )
        thread.start()

    @property
    def agent_profile(self) -> AgentProfile:
        return self.agent_manager.active_profile

    @property
    def config(self) -> VibeConfig:
        return self.agent_manager.config

    @property
    def auto_approve(self) -> bool:
        return self.config.auto_approve

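    # Permission changes are applied to the in-memory config immediately; with
    # save_permanently=True they are also persisted via VibeConfig.save_updates.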
    def set_tool_permission(
        self, tool_name: str, permission: ToolPermission, save_permanently: bool = False
    ) -> None:
        if save_permanently:
            VibeConfig.save_updates({
                "tools": {tool_name: {"permission": permission.value}}
            })

        if tool_name not in self.config.tools:
            self.config.tools[tool_name] = BaseToolConfig()

        self.config.tools[tool_name].permission = permission
        self.tool_manager.invalidate_tool(tool_name)

    def emit_new_session_telemetry(
        self, entrypoint: Literal["cli", "acp", "programmatic"]
    ) -> None:
        has_agents_md = has_agents_md_file(Path.cwd())
        nb_skills = len(self.skill_manager.available_skills)
        nb_mcp_servers = len(self.config.mcp_servers)
        nb_models = len(self.config.models)
        self.telemetry_client.send_new_session(
            has_agents_md=has_agents_md,
            nb_skills=nb_skills,
            nb_mcp_servers=nb_mcp_servers,
            nb_models=nb_models,
            entrypoint=entrypoint,
        )

    def _select_backend(self) -> BackendLike:
        active_model = self.config.get_active_model()
        provider = self.config.get_provider_for_model(active_model)
        timeout = self.config.api_timeout
        return BACKEND_FACTORY[provider.backend](provider=provider, timeout=timeout)

    def add_message(self, message: LLMMessage) -> None:
        self.messages.append(message)

    async def _save_messages(self) -> None:
        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )

    async def _flush_new_messages(self) -> None:
        await self._save_messages()

        if not self.message_observer:
            return

        if self._last_observed_message_index >= len(self.messages):
            return

        for msg in self.messages[self._last_observed_message_index :]:
            self.message_observer(msg)
        self._last_observed_message_index = len(self.messages)

    async def act(self, msg: str) -> AsyncGenerator[BaseEvent]:
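        """Run one user turn: repair message history, then yield conversation events."""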
        self._clean_message_history()
        async for event in self._conversation_loop(msg):
            yield event

    @property
    def teleport_service(self) -> TeleportService:
        if not _TELEPORT_AVAILABLE:
            raise TeleportError(
                "Teleport requires git to be installed. "
                "Please install git and try again."
            )

        if self._teleport_service is None:
            if _TeleportService is None:
                raise TeleportError("_TeleportService is unexpectedly None")
            self._teleport_service = _TeleportService(
                session_logger=self.session_logger,
                nuage_base_url=self.config.nuage_base_url,
                nuage_workflow_id=self.config.nuage_workflow_id,
                nuage_api_key=self.config.nuage_api_key,
            )
        return self._teleport_service

    def teleport_to_vibe_nuage(
        self, prompt: str | None
    ) -> AsyncGenerator[TeleportYieldEvent, TeleportPushResponseEvent | None]:
        from vibe.core.teleport.nuage import TeleportSession

        session = TeleportSession(
            metadata={
                "agent": self.agent_profile.name,
                "model": self.config.active_model,
                "stats": self.stats.model_dump(),
            },
            messages=[msg.model_dump(exclude_none=True) for msg in self.messages[1:]],
        )
        return self._teleport_generator(prompt, session)

    async def _teleport_generator(
        self, prompt: str | None, session: TeleportSession
    ) -> AsyncGenerator[TeleportYieldEvent, TeleportPushResponseEvent | None]:
        from vibe.core.teleport.errors import ServiceTeleportError

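        # Bidirectional generator: events are yielded to the caller, and the
        # caller's responses are pushed back into the service via asend().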
        try:
            async with self.teleport_service:
                gen = self.teleport_service.execute(prompt=prompt, session=session)
                response: TeleportPushResponseEvent | None = None
                while True:
                    try:
                        event = await gen.asend(response)
                        response = yield event
                    except StopAsyncIteration:
                        break
        except ServiceTeleportError as e:
            raise TeleportError(str(e)) from e
        finally:
            self._teleport_service = None

    def _setup_middleware(self) -> None:
        """Configure the middleware pipeline for this conversation."""
        self.middleware_pipeline.clear()

        if self._max_turns is not None:
            self.middleware_pipeline.add(TurnLimitMiddleware(self._max_turns))

        if self._max_price is not None:
            self.middleware_pipeline.add(PriceLimitMiddleware(self._max_price))

        if self.config.auto_compact_threshold > 0:
            self.middleware_pipeline.add(
                AutoCompactMiddleware(self.config.auto_compact_threshold)
            )
        if self.config.context_warnings:
            self.middleware_pipeline.add(
                ContextWarningMiddleware(0.5, self.config.auto_compact_threshold)
            )

        self.middleware_pipeline.add(PlanAgentMiddleware(lambda: self.agent_profile))

    async def _handle_middleware_result(
        self, result: MiddlewareResult
    ) -> AsyncGenerator[BaseEvent]:
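        """Translate a middleware verdict into events and message-history changes."""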
        match result.action:
            case MiddlewareAction.STOP:
                yield AssistantEvent(
                    content=f"<{VIBE_STOP_EVENT_TAG}>{result.reason}</{VIBE_STOP_EVENT_TAG}>",
                    stopped_by_middleware=True,
                )

            case MiddlewareAction.INJECT_MESSAGE:
                if result.message:
                    injected_message = LLMMessage(
                        role=Role.user, content=result.message
                    )
                    self.messages.append(injected_message)

            case MiddlewareAction.COMPACT:
                old_tokens = result.metadata.get(
                    "old_tokens", self.stats.context_tokens
                )
                threshold = result.metadata.get(
                    "threshold", self.config.auto_compact_threshold
                )
                tool_call_id = str(uuid4())

                yield CompactStartEvent(
                    tool_call_id=tool_call_id,
                    current_context_tokens=old_tokens,
                    threshold=threshold,
                )
                self.telemetry_client.send_auto_compact_triggered()

                summary = await self.compact()

                yield CompactEndEvent(
                    tool_call_id=tool_call_id,
                    old_context_tokens=old_tokens,
                    new_context_tokens=self.stats.context_tokens,
                    summary_length=len(summary),
                )

            case MiddlewareAction.CONTINUE:
                pass

    def _get_context(self) -> ConversationContext:
        return ConversationContext(
            messages=self.messages, stats=self.stats, config=self.config
        )

    def _get_extra_headers(self, provider: ProviderConfig) -> dict[str, str]:
        headers: dict[str, str] = {
            "user-agent": get_user_agent(provider.backend),
            "x-affinity": self.session_id,
        }
        if (
            provider.backend == Backend.MISTRAL
            and self._current_user_message_id is not None
        ):
            headers["metadata"] = json.dumps({
                "message_id": self._current_user_message_id
            })
        return headers

    async def _conversation_loop(self, user_msg: str) -> AsyncGenerator[BaseEvent]:
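        """Top-level turn loop: run middleware, call the LLM, execute any tool
        calls, and repeat until the assistant replies without tools, the user
        cancels, or a middleware STOP fires."""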
        user_message = LLMMessage(role=Role.user, content=user_msg)
        self.messages.append(user_message)
        self.stats.steps += 1
        self._current_user_message_id = user_message.message_id

        if user_message.message_id is None:
            raise AgentLoopError("User message must have a message_id")

        yield UserMessageEvent(content=user_msg, message_id=user_message.message_id)

        try:
            should_break_loop = False
            while not should_break_loop:
                result = await self.middleware_pipeline.run_before_turn(
                    self._get_context()
                )
                async for event in self._handle_middleware_result(result):
                    yield event

                if result.action == MiddlewareAction.STOP:
                    return

                self.stats.steps += 1
                user_cancelled = False
                async for event in self._perform_llm_turn():
                    if is_user_cancellation_event(event):
                        user_cancelled = True
                    yield event
                await self._flush_new_messages()

                last_message = self.messages[-1]
                should_break_loop = last_message.role != Role.tool

                if user_cancelled:
                    return

        finally:
            await self._flush_new_messages()

    async def _perform_llm_turn(self) -> AsyncGenerator[BaseEvent, None]:
        if self.enable_streaming:
            async for event in self._stream_assistant_events():
                yield event
        else:
            assistant_event = await self._get_assistant_event()
            if assistant_event.content:
                yield assistant_event

        last_message = self.messages[-1]

        parsed = self.format_handler.parse_message(last_message)
        resolved = self.format_handler.resolve_tool_calls(parsed, self.tool_manager)

        if not resolved.tool_calls and not resolved.failed_calls:
            return

        async for event in self._handle_tool_calls(resolved):
            yield event

    async def _stream_assistant_events(
        self,
    ) -> AsyncGenerator[AssistantEvent | ReasoningEvent]:
        content_buffer = ""
        reasoning_buffer = ""
        chunks_with_content = 0
        chunks_with_reasoning = 0
        message_id: str | None = None
        BATCH_SIZE = 5

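        # Chunks are coalesced into batches of BATCH_SIZE before an event is
        # emitted, trading a little latency for fewer downstream updates. When
        # the stream switches between content and reasoning, the other buffer
        # is flushed first so event ordering matches the stream.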
        async for chunk in self._chat_streaming():
            if message_id is None:
                message_id = chunk.message.message_id

            if chunk.message.reasoning_content:
                if content_buffer:
                    yield AssistantEvent(content=content_buffer, message_id=message_id)
                    content_buffer = ""
                    chunks_with_content = 0

                reasoning_buffer += chunk.message.reasoning_content
                chunks_with_reasoning += 1

                if chunks_with_reasoning >= BATCH_SIZE:
                    yield ReasoningEvent(
                        content=reasoning_buffer, message_id=message_id
                    )
                    reasoning_buffer = ""
                    chunks_with_reasoning = 0

            if chunk.message.content:
                if reasoning_buffer:
                    yield ReasoningEvent(
                        content=reasoning_buffer, message_id=message_id
                    )
                    reasoning_buffer = ""
                    chunks_with_reasoning = 0

                content_buffer += chunk.message.content
                chunks_with_content += 1

                if chunks_with_content >= BATCH_SIZE:
                    yield AssistantEvent(content=content_buffer, message_id=message_id)
                    content_buffer = ""
                    chunks_with_content = 0

        if reasoning_buffer:
            yield ReasoningEvent(content=reasoning_buffer, message_id=message_id)

        if content_buffer:
            yield AssistantEvent(content=content_buffer, message_id=message_id)

    async def _get_assistant_event(self) -> AssistantEvent:
        llm_result = await self._chat()
        return AssistantEvent(
            content=llm_result.message.content or "",
            message_id=llm_result.message.message_id,
        )

    async def _emit_failed_tool_events(
        self, failed_calls: list[FailedToolCall]
    ) -> AsyncGenerator[ToolResultEvent]:
        for failed in failed_calls:
            error_msg = f"<{TOOL_ERROR_TAG}>{failed.tool_name}: {failed.error}</{TOOL_ERROR_TAG}>"
            yield ToolResultEvent(
                tool_name=failed.tool_name,
                tool_class=None,
                error=error_msg,
                tool_call_id=failed.call_id,
            )
            self.stats.tool_calls_failed += 1
            self.messages.append(
                self.format_handler.create_failed_tool_response_message(
                    failed, error_msg
                )
            )

    async def _process_one_tool_call(
        self, tool_call: ResolvedToolCall
    ) -> AsyncGenerator[ToolResultEvent | ToolStreamEvent]:
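        """Gate and run a single tool call, yielding stream and result events.

        Flow: look up the tool, get a verdict from _should_execute_tool, then
        invoke the tool and record the outcome in stats and message history.
        """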
        try:
            tool_instance = self.tool_manager.get(tool_call.tool_name)
        except Exception as exc:
            error_msg = f"Error getting tool '{tool_call.tool_name}': {exc}"
            yield ToolResultEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                error=error_msg,
                tool_call_id=tool_call.call_id,
            )
            self._handle_tool_response(tool_call, error_msg, "failure")
            return

        decision = await self._should_execute_tool(
            tool_instance, tool_call.validated_args, tool_call.call_id
        )

        if decision.verdict == ToolExecutionResponse.SKIP:
            self.stats.tool_calls_rejected += 1
            skip_reason = decision.feedback or str(
                get_user_cancellation_message(
                    CancellationReason.TOOL_SKIPPED, tool_call.tool_name
                )
            )
            yield ToolResultEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                skipped=True,
                skip_reason=skip_reason,
                tool_call_id=tool_call.call_id,
            )
            self._handle_tool_response(tool_call, skip_reason, "skipped", decision)
            return

        self.stats.tool_calls_agreed += 1

        try:
            start_time = time.perf_counter()
            result_model = None
            async for item in tool_instance.invoke(
                ctx=InvokeContext(
                    tool_call_id=tool_call.call_id,
                    approval_callback=self.approval_callback,
                    agent_manager=self.agent_manager,
                    user_input_callback=self.user_input_callback,
                ),
                **tool_call.args_dict,
            ):
                if isinstance(item, ToolStreamEvent):
                    yield item
                else:
                    result_model = item

            duration = time.perf_counter() - start_time
            if result_model is None:
                raise ToolError("Tool did not yield a result")

            result_dict = result_model.model_dump()
            text = "\n".join(f"{k}: {v}" for k, v in result_dict.items())
            self._handle_tool_response(
                tool_call, text, "success", decision, result_dict
            )
            yield ToolResultEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                result=result_model,
                duration=duration,
                tool_call_id=tool_call.call_id,
            )
            self.stats.tool_calls_succeeded += 1

        except asyncio.CancelledError:
            cancel = str(
                get_user_cancellation_message(CancellationReason.TOOL_INTERRUPTED)
            )
            yield ToolResultEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                error=cancel,
                tool_call_id=tool_call.call_id,
            )
            self._handle_tool_response(tool_call, cancel, "failure", decision)
            raise

        except (ToolError, ToolPermissionError) as exc:
            error_msg = f"<{TOOL_ERROR_TAG}>{tool_instance.get_name()} failed: {exc}</{TOOL_ERROR_TAG}>"
            yield ToolResultEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                error=error_msg,
                tool_call_id=tool_call.call_id,
            )
            if isinstance(exc, ToolPermissionError):
                self.stats.tool_calls_agreed -= 1
                self.stats.tool_calls_rejected += 1
            else:
                self.stats.tool_calls_failed += 1
            self._handle_tool_response(tool_call, error_msg, "failure", decision)

    async def _handle_tool_calls(
        self, resolved: ResolvedMessage
    ) -> AsyncGenerator[ToolCallEvent | ToolResultEvent | ToolStreamEvent]:
        async for event in self._emit_failed_tool_events(resolved.failed_calls):
            yield event
        for tool_call in resolved.tool_calls:
            yield ToolCallEvent(
                tool_name=tool_call.tool_name,
                tool_class=tool_call.tool_class,
                args=tool_call.validated_args,
                tool_call_id=tool_call.call_id,
            )
            async for event in self._process_one_tool_call(tool_call):
                yield event

    def _handle_tool_response(
        self,
        tool_call: ResolvedToolCall,
        text: str,
        status: Literal["success", "failure", "skipped"],
        decision: ToolDecision | None = None,
        result: dict[str, Any] | None = None,
    ) -> None:
        self.messages.append(
            LLMMessage.model_validate(
                self.format_handler.create_tool_response_message(tool_call, text)
            )
        )

        self.telemetry_client.send_tool_call_finished(
            tool_call=tool_call,
            agent_profile_name=self.agent_profile.name,
            status=status,
            decision=decision,
            result=result,
        )

    async def _chat(self, max_tokens: int | None = None) -> LLMChunk:
        active_model = self.config.get_active_model()
        provider = self.config.get_provider_for_model(active_model)

        available_tools = self.format_handler.get_available_tools(self.tool_manager)
        tool_choice = self.format_handler.get_tool_choice()

        try:
            start_time = time.perf_counter()
            result = await self.backend.complete(
                model=active_model,
                messages=self.messages,
                temperature=active_model.temperature,
                tools=available_tools,
                tool_choice=tool_choice,
                extra_headers=self._get_extra_headers(provider),
                max_tokens=max_tokens,
            )
            end_time = time.perf_counter()

            if result.usage is None:
                raise AgentLoopLLMResponseError(
                    "Usage data missing in non-streaming completion response"
                )
            self._update_stats(usage=result.usage, time_seconds=end_time - start_time)

            processed_message = self.format_handler.process_api_response_message(
                result.message
            )
            self.messages.append(processed_message)
            return LLMChunk(message=processed_message, usage=result.usage)

        except Exception as e:
            if _should_raise_rate_limit_error(e):
                raise RateLimitError(provider.name, active_model.name) from e

            raise RuntimeError(
                f"API error from {provider.name} (model: {active_model.name}): {e}"
            ) from e

    async def _chat_streaming(
        self, max_tokens: int | None = None
    ) -> AsyncGenerator[LLMChunk]:
        active_model = self.config.get_active_model()
        provider = self.config.get_provider_for_model(active_model)

        available_tools = self.format_handler.get_available_tools(self.tool_manager)
        tool_choice = self.format_handler.get_tool_choice()

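        # Each streamed chunk is yielded to the caller and also folded into
        # chunk_agg, so the complete assistant message can be appended to the
        # history once the stream finishes.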
        try:
            start_time = time.perf_counter()
            usage = LLMUsage()
            chunk_agg = LLMChunk(message=LLMMessage(role=Role.assistant))
            async for chunk in self.backend.complete_streaming(
                model=active_model,
                messages=self.messages,
                temperature=active_model.temperature,
                tools=available_tools,
                tool_choice=tool_choice,
                extra_headers=self._get_extra_headers(provider),
                max_tokens=max_tokens,
            ):
                processed_message = self.format_handler.process_api_response_message(
                    chunk.message
                )
                processed_chunk = LLMChunk(message=processed_message, usage=chunk.usage)
                chunk_agg += processed_chunk
                usage += chunk.usage or LLMUsage()
                yield processed_chunk
            end_time = time.perf_counter()

            if chunk_agg.usage is None:
                raise AgentLoopLLMResponseError(
                    "Usage data missing in final chunk of streamed completion"
                )
            self._update_stats(usage=usage, time_seconds=end_time - start_time)

            self.messages.append(chunk_agg.message)

        except Exception as e:
            if _should_raise_rate_limit_error(e):
                raise RateLimitError(provider.name, active_model.name) from e

            raise RuntimeError(
                f"API error from {provider.name} (model: {active_model.name}): {e}"
            ) from e

    def _update_stats(self, usage: LLMUsage, time_seconds: float) -> None:
        self.stats.last_turn_duration = time_seconds
        self.stats.last_turn_prompt_tokens = usage.prompt_tokens
        self.stats.last_turn_completion_tokens = usage.completion_tokens
        self.stats.session_prompt_tokens += usage.prompt_tokens
        self.stats.session_completion_tokens += usage.completion_tokens
        self.stats.context_tokens = usage.prompt_tokens + usage.completion_tokens
        if time_seconds > 0 and usage.completion_tokens > 0:
            self.stats.tokens_per_second = usage.completion_tokens / time_seconds

    async def _should_execute_tool(
        self, tool: BaseTool, args: BaseModel, tool_call_id: str
    ) -> ToolDecision:
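        """Decide whether a tool call may run.

        Precedence: global auto-approve, then the tool's allowlist/denylist,
        then the configured per-tool permission, and finally an interactive
        approval prompt via the approval callback.
        """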
        if self.auto_approve:
            return ToolDecision(
                verdict=ToolExecutionResponse.EXECUTE,
                approval_type=ToolPermission.ALWAYS,
            )

        allowlist_denylist_result = tool.check_allowlist_denylist(args)
        if allowlist_denylist_result == ToolPermission.ALWAYS:
            return ToolDecision(
                verdict=ToolExecutionResponse.EXECUTE,
                approval_type=ToolPermission.ALWAYS,
            )
        elif allowlist_denylist_result == ToolPermission.NEVER:
            denylist_patterns = tool.config.denylist
            denylist_str = ", ".join(repr(pattern) for pattern in denylist_patterns)
            return ToolDecision(
                verdict=ToolExecutionResponse.SKIP,
                approval_type=ToolPermission.NEVER,
                feedback=f"Tool '{tool.get_name()}' blocked by denylist: [{denylist_str}]",
            )

        tool_name = tool.get_name()
        perm = self.tool_manager.get_tool_config(tool_name).permission

        if perm is ToolPermission.ALWAYS:
            return ToolDecision(
                verdict=ToolExecutionResponse.EXECUTE,
                approval_type=ToolPermission.ALWAYS,
            )
        if perm is ToolPermission.NEVER:
            return ToolDecision(
                verdict=ToolExecutionResponse.SKIP,
                approval_type=ToolPermission.NEVER,
                feedback=f"Tool '{tool_name}' is permanently disabled",
            )

        return await self._ask_approval(tool_name, args, tool_call_id)

    async def _ask_approval(
        self, tool_name: str, args: BaseModel, tool_call_id: str
    ) -> ToolDecision:
        if not self.approval_callback:
            return ToolDecision(
                verdict=ToolExecutionResponse.SKIP,
                approval_type=ToolPermission.ASK,
                feedback="Tool execution not permitted.",
            )
        if asyncio.iscoroutinefunction(self.approval_callback):
            async_callback = cast(AsyncApprovalCallback, self.approval_callback)
            response, feedback = await async_callback(tool_name, args, tool_call_id)
        else:
            sync_callback = cast(SyncApprovalCallback, self.approval_callback)
            response, feedback = sync_callback(tool_name, args, tool_call_id)

        match response:
            case ApprovalResponse.YES:
                return ToolDecision(
                    verdict=ToolExecutionResponse.EXECUTE,
                    approval_type=ToolPermission.ASK,
                    feedback=feedback,
                )
            case ApprovalResponse.NO:
                return ToolDecision(
                    verdict=ToolExecutionResponse.SKIP,
                    approval_type=ToolPermission.ASK,
                    feedback=feedback,
                )

    def _clean_message_history(self) -> None:
        ACCEPTABLE_HISTORY_SIZE = 2
        if len(self.messages) < ACCEPTABLE_HISTORY_SIZE:
            return
        self._fill_missing_tool_responses()
        self._ensure_assistant_after_tools()

    def _fill_missing_tool_responses(self) -> None:
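        # Chat APIs expect one tool message per tool_call in the preceding
        # assistant message. If a run was interrupted, synthesize cancellation
        # responses for any calls left unanswered so the history stays well-formed.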
        i = 1
        while i < len(self.messages):  # noqa: PLR1702
            msg = self.messages[i]

            if msg.role == "assistant" and msg.tool_calls:
                expected_responses = len(msg.tool_calls)

                if expected_responses > 0:
                    actual_responses = 0
                    j = i + 1
                    while j < len(self.messages) and self.messages[j].role == "tool":
                        actual_responses += 1
                        j += 1

                    if actual_responses < expected_responses:
                        insertion_point = i + 1 + actual_responses

                        for call_idx in range(actual_responses, expected_responses):
                            tool_call_data = msg.tool_calls[call_idx]

                            empty_response = LLMMessage(
                                role=Role.tool,
                                tool_call_id=tool_call_data.id or "",
                                name=(tool_call_data.function.name or "")
                                if tool_call_data.function
                                else "",
                                content=str(
                                    get_user_cancellation_message(
                                        CancellationReason.TOOL_NO_RESPONSE
                                    )
                                ),
                            )

                            self.messages.insert(insertion_point, empty_response)
                            insertion_point += 1

                    i = i + 1 + expected_responses
                    continue

            i += 1

    def _ensure_assistant_after_tools(self) -> None:
        MIN_MESSAGE_SIZE = 2
        if len(self.messages) < MIN_MESSAGE_SIZE:
            return

        last_msg = self.messages[-1]
        if last_msg.role is Role.tool:
            empty_assistant_msg = LLMMessage(role=Role.assistant, content="Understood.")
            self.messages.append(empty_assistant_msg)

    def _reset_session(self) -> None:
        self.session_id = str(uuid4())
        self.session_logger.reset_session(self.session_id)

    def set_approval_callback(self, callback: ApprovalCallback) -> None:
        self.approval_callback = callback

    def set_user_input_callback(self, callback: UserInputCallback) -> None:
        self.user_input_callback = callback

    async def clear_history(self) -> None:
        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )
        self.messages = self.messages[:1]

        self.stats = AgentStats()
        self.stats.trigger_listeners()

        try:
            active_model = self.config.get_active_model()
            self.stats.update_pricing(
                active_model.input_price, active_model.output_price
            )
        except ValueError:
            pass

        self.middleware_pipeline.reset()
        self.tool_manager.reset_all()
        self._reset_session()

    async def compact(self) -> str:
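        """Summarize the conversation and replace the history with the summary.

        On success the session is reset and the middleware pipeline is notified
        with ResetReason.COMPACT; on failure the pre-compaction history is saved
        before the error propagates.
        """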
        try:
            self._clean_message_history()
            await self.session_logger.save_interaction(
                self.messages,
                self.stats,
                self._base_config,
                self.tool_manager,
                self.agent_profile,
            )

            summary_request = UtilityPrompt.COMPACT.read()
            self.messages.append(LLMMessage(role=Role.user, content=summary_request))
            self.stats.steps += 1

            summary_result = await self._chat()
            if summary_result.usage is None:
                raise AgentLoopLLMResponseError(
                    "Usage data missing in compaction summary response"
                )
            summary_content = summary_result.message.content or ""

            system_message = self.messages[0]
            summary_message = LLMMessage(role=Role.user, content=summary_content)
            self.messages = [system_message, summary_message]
            self._last_observed_message_index = 1

            active_model = self.config.get_active_model()
            provider = self.config.get_provider_for_model(active_model)

            actual_context_tokens = await self.backend.count_tokens(
                model=active_model,
                messages=self.messages,
                tools=self.format_handler.get_available_tools(self.tool_manager),
                extra_headers={"user-agent": get_user_agent(provider.backend)},
            )

            self.stats.context_tokens = actual_context_tokens

            self._reset_session()
            await self.session_logger.save_interaction(
                self.messages,
                self.stats,
                self._base_config,
                self.tool_manager,
                self.agent_profile,
            )

            self.middleware_pipeline.reset(reset_reason=ResetReason.COMPACT)

            return summary_content

        except Exception:
            await self.session_logger.save_interaction(
                self.messages,
                self.stats,
                self._base_config,
                self.tool_manager,
                self.agent_profile,
            )
            raise

    async def switch_agent(self, agent_name: str) -> None:
        if agent_name == self.agent_profile.name:
            return
        self.agent_manager.switch_profile(agent_name)
        await self.reload_with_initial_messages(reset_middleware=False)

    async def reload_with_initial_messages(
        self,
        base_config: VibeConfig | None = None,
        max_turns: int | None = None,
        max_price: float | None = None,
        reset_middleware: bool = True,
    ) -> None:
        # Force an immediate yield to allow the UI to update before heavy sync work.
        # When there are no messages, save_interaction returns early without any await,
        # so the coroutine would run synchronously through ToolManager, SkillManager,
        # and system prompt generation without yielding control to the event loop.
        await asyncio.sleep(0)

        await self.session_logger.save_interaction(
            self.messages,
            self.stats,
            self._base_config,
            self.tool_manager,
            self.agent_profile,
        )

        if base_config is not None:
            self._base_config = base_config
            self.agent_manager.invalidate_config()

        self.backend = self.backend_factory()

        if max_turns is not None:
            self._max_turns = max_turns
        if max_price is not None:
            self._max_price = max_price

        self.tool_manager = ToolManager(lambda: self.config)
        self.skill_manager = SkillManager(lambda: self.config)

        new_system_prompt = get_universal_system_prompt(
            self.tool_manager, self.config, self.skill_manager, self.agent_manager
        )

        self.messages = [
            LLMMessage(role=Role.system, content=new_system_prompt),
            *[msg for msg in self.messages if msg.role != Role.system],
        ]

        if len(self.messages) == 1:
            self.stats.reset_context_state()

        try:
            active_model = self.config.get_active_model()
            self.stats.update_pricing(
                active_model.input_price, active_model.output_price
            )
        except ValueError:
            pass

        if reset_middleware:
            self._setup_middleware()