Files
mistral-vibe/vibe/core/middleware.py
Mathias Gesbert ec7f3b25ea v2.2.0 (#395)
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai>
Co-authored-by: Clément Siriex <clement.sirieix@mistral.ai>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Michel Thomazo <michel.thomazo@mistral.ai>
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
2026-02-17 16:23:28 +01:00

210 lines
7.3 KiB
Python

from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import StrEnum, auto
from typing import TYPE_CHECKING, Any, Protocol
from vibe.core.agents import AgentProfile
from vibe.core.agents.models import BuiltinAgentName
from vibe.core.utils import VIBE_WARNING_TAG
if TYPE_CHECKING:
from vibe.core.config import VibeConfig
from vibe.core.types import AgentStats, LLMMessage
class MiddlewareAction(StrEnum):
CONTINUE = auto()
STOP = auto()
COMPACT = auto()
INJECT_MESSAGE = auto()
class ResetReason(StrEnum):
STOP = auto()
COMPACT = auto()
@dataclass
class ConversationContext:
    """Bundle of conversation state passed to every ``before_turn`` hook."""

    # Messages of the conversation so far.
    messages: list[LLMMessage]
    # Counters read by the middlewares in this module (steps, session_cost,
    # context_tokens).
    stats: AgentStats
    # Configuration object; not read by any middleware in this module.
    config: VibeConfig
@dataclass
class MiddlewareResult:
    """Outcome of a single ``before_turn`` call.

    A bare ``MiddlewareResult()`` means "continue, nothing to report".
    """

    # What the middleware is asking for; defaults to CONTINUE.
    action: MiddlewareAction = MiddlewareAction.CONTINUE
    # Text to inject; the pipeline only honours it for INJECT_MESSAGE results.
    message: str | None = None
    # Human-readable explanation, set by the STOP results in this module.
    reason: str | None = None
    # Free-form extra data (e.g. token counts on COMPACT results). A factory
    # is used so instances never share one dict.
    metadata: dict[str, Any] = field(default_factory=dict)
class ConversationMiddleware(Protocol):
    """Structural interface every middleware in the pipeline satisfies."""

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Inspect ``context`` and report the action to take."""
        ...

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Clear any per-conversation state held by the middleware."""
        ...
class TurnLimitMiddleware:
    """Request STOP once the step counter reaches a configured cap."""

    def __init__(self, max_turns: int) -> None:
        """Store the cap that ``before_turn`` compares against."""
        self.max_turns = max_turns

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Return STOP when ``stats.steps - 1`` is at or above ``max_turns``.

        Otherwise return the default (CONTINUE) result.
        """
        # NOTE(review): the ``- 1`` presumably discounts the step currently in
        # progress; confirm against how AgentStats.steps is incremented.
        limit_reached = context.stats.steps - 1 >= self.max_turns
        if not limit_reached:
            return MiddlewareResult()

        return MiddlewareResult(
            action=MiddlewareAction.STOP,
            reason=f"Turn limit of {self.max_turns} reached",
        )

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Do nothing: this middleware keeps no per-conversation state."""
class PriceLimitMiddleware:
    """Request STOP once the session cost exceeds a configured budget."""

    def __init__(self, max_price: float) -> None:
        """Store the budget that ``before_turn`` compares against."""
        self.max_price = max_price

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Return STOP when ``stats.session_cost`` is strictly above ``max_price``.

        Otherwise return the default (CONTINUE) result.
        """
        # Kept as a strict ``>`` (not an inverted ``<=``) so a NaN cost still
        # compares false and falls through to CONTINUE.
        over_budget = context.stats.session_cost > self.max_price
        if not over_budget:
            return MiddlewareResult()

        return MiddlewareResult(
            action=MiddlewareAction.STOP,
            reason=f"Price limit exceeded: ${context.stats.session_cost:.4f} > ${self.max_price:.2f}",
        )

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Do nothing: this middleware keeps no per-conversation state."""
class AutoCompactMiddleware:
    """Request COMPACT once the context token count reaches a threshold."""

    def __init__(self, threshold: int) -> None:
        """Store the token count at which compaction is requested."""
        self.threshold = threshold

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Return COMPACT when ``stats.context_tokens`` is at or above the threshold.

        The COMPACT result's metadata records the current token count under
        ``old_tokens`` and the configured ``threshold``. Otherwise the default
        (CONTINUE) result is returned.
        """
        current_tokens = context.stats.context_tokens
        needs_compaction = current_tokens >= self.threshold
        if not needs_compaction:
            return MiddlewareResult()

        details = {"old_tokens": current_tokens, "threshold": self.threshold}
        return MiddlewareResult(action=MiddlewareAction.COMPACT, metadata=details)

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Do nothing: this middleware keeps no per-conversation state."""
class ContextWarningMiddleware:
    """Inject a one-time warning once context usage crosses a fraction of the window.

    The warning is emitted at most once until ``reset`` is called.
    """

    def __init__(
        self, threshold_percent: float = 0.5, max_context: int | None = None
    ) -> None:
        """Configure the warning.

        Args:
            threshold_percent: Fraction of ``max_context`` at which to warn.
                Despite the name it is a ratio (0.5 means 50%), because it is
                multiplied directly with ``max_context``.
            max_context: Total context size in tokens. ``None`` or a
                non-positive value disables the warning.
        """
        self.threshold_percent = threshold_percent
        self.max_context = max_context
        self.has_warned = False

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Return an INJECT_MESSAGE warning the first time the threshold is reached.

        Returns the default (CONTINUE) result when the warning was already
        emitted, when no usable ``max_context`` is configured, or when usage
        is still below the threshold.
        """
        if self.has_warned:
            return MiddlewareResult()

        max_context = self.max_context
        # An unknown or non-positive window size cannot produce a meaningful
        # percentage: 0 would raise ZeroDivisionError below, and a negative
        # value would always "exceed" the threshold with a negative percentage.
        if max_context is None or max_context <= 0:
            return MiddlewareResult()

        if context.stats.context_tokens >= max_context * self.threshold_percent:
            # Latch before building the message so the warning fires only once.
            self.has_warned = True
            percentage_used = (context.stats.context_tokens / max_context) * 100
            warning_msg = f"<{VIBE_WARNING_TAG}>You have used {percentage_used:.0f}% of your total context ({context.stats.context_tokens:,}/{max_context:,} tokens)</{VIBE_WARNING_TAG}>"
            return MiddlewareResult(
                action=MiddlewareAction.INJECT_MESSAGE, message=warning_msg
            )

        return MiddlewareResult()

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Re-arm the warning, regardless of ``reset_reason``."""
        self.has_warned = False
# Default reminder text for PlanAgentMiddleware, injected when the active
# agent profile switches to the plan agent. Wrapped in the VIBE_WARNING_TAG
# element. The text is sent verbatim, so edits here change runtime behavior.
PLAN_AGENT_REMINDER = f"""<{VIBE_WARNING_TAG}>Plan mode is active. The user indicated that they do not want you to execute yet -- you MUST NOT make any edits, run any non-readonly tools (including changing configs or making commits), or otherwise make any changes to the system. This supersedes any other instructions you have received (for example, to make edits). Instead, you should:
1. Answer the user's query comprehensively
2. When you're done researching, present your plan by giving the full plan and not doing further tool calls to return input to the user. Do NOT make any file changes or run any tools that modify the system state in any way until the user has confirmed the plan.</{VIBE_WARNING_TAG}>"""
# Default exit text for PlanAgentMiddleware, injected when the active agent
# profile switches away from the plan agent.
PLAN_AGENT_EXIT = f"""<{VIBE_WARNING_TAG}>Plan mode has ended. If you have a plan ready, you can now start executing it. If not, you can now use editing tools and make changes to the system.</{VIBE_WARNING_TAG}>"""
class PlanAgentMiddleware:
    """Inject a message whenever the active profile enters or leaves the plan agent.

    The previous plan/non-plan state is remembered between calls, so a
    message is produced only on a transition, never on every turn.
    """

    def __init__(
        self,
        profile_getter: Callable[[], AgentProfile],
        reminder: str = PLAN_AGENT_REMINDER,
        exit_message: str = PLAN_AGENT_EXIT,
    ) -> None:
        """Configure the middleware.

        Args:
            profile_getter: Zero-argument callable returning the current
                agent profile; called on every ``before_turn``.
            reminder: Message injected when the plan agent becomes active.
            exit_message: Message injected when the plan agent stops being
                active.
        """
        self._profile_getter = profile_getter
        self.reminder = reminder
        self.exit_message = exit_message
        self._was_plan_agent = False

    def _is_plan_agent(self) -> bool:
        """Report whether the current profile's name is the built-in plan agent."""
        current_profile = self._profile_getter()
        return current_profile.name == BuiltinAgentName.PLAN

    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Return an INJECT_MESSAGE result on a plan-mode transition.

        Entering plan mode yields ``reminder``; leaving it yields
        ``exit_message``; an unchanged state yields the default (CONTINUE)
        result. ``context`` is not inspected.
        """
        now_plan = self._is_plan_agent()
        previously_plan = self._was_plan_agent
        # Every branch of the transition logic ends with the stored state
        # equal to the current state, so record it once up front.
        self._was_plan_agent = now_plan

        if now_plan == previously_plan:
            return MiddlewareResult()

        text = self.reminder if now_plan else self.exit_message
        return MiddlewareResult(action=MiddlewareAction.INJECT_MESSAGE, message=text)

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Forget the remembered state, regardless of ``reset_reason``.

        If the plan agent is still active, the next ``before_turn`` call
        injects the reminder again.
        """
        self._was_plan_agent = False
class MiddlewarePipeline:
    """Ordered collection of middlewares run before each turn."""

    def __init__(self) -> None:
        """Start with an empty pipeline."""
        self.middlewares: list[ConversationMiddleware] = []

    def add(self, middleware: ConversationMiddleware) -> MiddlewarePipeline:
        """Append ``middleware`` and return the pipeline so calls can be chained."""
        self.middlewares.append(middleware)
        return self

    def clear(self) -> None:
        """Remove every registered middleware."""
        self.middlewares.clear()

    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        """Forward ``reset_reason`` to each middleware's ``reset``, in order."""
        for middleware in self.middlewares:
            middleware.reset(reset_reason)

    async def run_before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Run every middleware's ``before_turn`` in registration order.

        The first STOP or COMPACT result is returned immediately; later
        middlewares are not run and messages collected so far are dropped.
        Otherwise all non-empty INJECT_MESSAGE messages are joined with a
        blank line into a single INJECT_MESSAGE result. If there are none,
        the default (CONTINUE) result is returned.
        """
        terminal_actions = {MiddlewareAction.STOP, MiddlewareAction.COMPACT}
        pending: list[str] = []

        for middleware in self.middlewares:
            outcome = await middleware.before_turn(context)
            if outcome.action == MiddlewareAction.INJECT_MESSAGE and outcome.message:
                pending.append(outcome.message)
                continue
            if outcome.action in terminal_actions:
                return outcome

        if not pending:
            return MiddlewareResult()

        return MiddlewareResult(
            action=MiddlewareAction.INJECT_MESSAGE, message="\n\n".join(pending)
        )