mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Quentin Torroba <quentin.torroba@mistral.ai> Co-authored-by: Clément Siriex <clement.sirieix@mistral.ai> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Michel Thomazo <michel.thomazo@mistral.ai> Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
210 lines
7.3 KiB
Python
210 lines
7.3 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from enum import StrEnum, auto
|
|
from typing import TYPE_CHECKING, Any, Protocol
|
|
|
|
from vibe.core.agents import AgentProfile
|
|
from vibe.core.agents.models import BuiltinAgentName
|
|
from vibe.core.utils import VIBE_WARNING_TAG
|
|
|
|
if TYPE_CHECKING:
|
|
from vibe.core.config import VibeConfig
|
|
from vibe.core.types import AgentStats, LLMMessage
|
|
|
|
|
|
class MiddlewareAction(StrEnum):
|
|
CONTINUE = auto()
|
|
STOP = auto()
|
|
COMPACT = auto()
|
|
INJECT_MESSAGE = auto()
|
|
|
|
|
|
class ResetReason(StrEnum):
|
|
STOP = auto()
|
|
COMPACT = auto()
|
|
|
|
|
|
@dataclass
|
|
class ConversationContext:
|
|
messages: list[LLMMessage]
|
|
stats: AgentStats
|
|
config: VibeConfig
|
|
|
|
|
|
@dataclass
|
|
class MiddlewareResult:
|
|
action: MiddlewareAction = MiddlewareAction.CONTINUE
|
|
message: str | None = None
|
|
reason: str | None = None
|
|
metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
class ConversationMiddleware(Protocol):
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult: ...
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None: ...
|
|
|
|
|
|
class TurnLimitMiddleware:
|
|
def __init__(self, max_turns: int) -> None:
|
|
self.max_turns = max_turns
|
|
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
if context.stats.steps - 1 >= self.max_turns:
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.STOP,
|
|
reason=f"Turn limit of {self.max_turns} reached",
|
|
)
|
|
return MiddlewareResult()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
pass
|
|
|
|
|
|
class PriceLimitMiddleware:
|
|
def __init__(self, max_price: float) -> None:
|
|
self.max_price = max_price
|
|
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
if context.stats.session_cost > self.max_price:
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.STOP,
|
|
reason=f"Price limit exceeded: ${context.stats.session_cost:.4f} > ${self.max_price:.2f}",
|
|
)
|
|
return MiddlewareResult()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
pass
|
|
|
|
|
|
class AutoCompactMiddleware:
|
|
def __init__(self, threshold: int) -> None:
|
|
self.threshold = threshold
|
|
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
if context.stats.context_tokens >= self.threshold:
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.COMPACT,
|
|
metadata={
|
|
"old_tokens": context.stats.context_tokens,
|
|
"threshold": self.threshold,
|
|
},
|
|
)
|
|
return MiddlewareResult()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
pass
|
|
|
|
|
|
class ContextWarningMiddleware:
|
|
def __init__(
|
|
self, threshold_percent: float = 0.5, max_context: int | None = None
|
|
) -> None:
|
|
self.threshold_percent = threshold_percent
|
|
self.max_context = max_context
|
|
self.has_warned = False
|
|
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
if self.has_warned:
|
|
return MiddlewareResult()
|
|
|
|
max_context = self.max_context
|
|
if max_context is None:
|
|
return MiddlewareResult()
|
|
|
|
if context.stats.context_tokens >= max_context * self.threshold_percent:
|
|
self.has_warned = True
|
|
|
|
percentage_used = (context.stats.context_tokens / max_context) * 100
|
|
warning_msg = f"<{VIBE_WARNING_TAG}>You have used {percentage_used:.0f}% of your total context ({context.stats.context_tokens:,}/{max_context:,} tokens)</{VIBE_WARNING_TAG}>"
|
|
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.INJECT_MESSAGE, message=warning_msg
|
|
)
|
|
|
|
return MiddlewareResult()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
self.has_warned = False
|
|
|
|
|
|
PLAN_AGENT_REMINDER = f"""<{VIBE_WARNING_TAG}>Plan mode is active. The user indicated that they do not want you to execute yet -- you MUST NOT make any edits, run any non-readonly tools (including changing configs or making commits), or otherwise make any changes to the system. This supersedes any other instructions you have received (for example, to make edits). Instead, you should:
|
|
1. Answer the user's query comprehensively
|
|
2. When you're done researching, present your plan by giving the full plan and not doing further tool calls to return input to the user. Do NOT make any file changes or run any tools that modify the system state in any way until the user has confirmed the plan.</{VIBE_WARNING_TAG}>"""
|
|
|
|
PLAN_AGENT_EXIT = f"""<{VIBE_WARNING_TAG}>Plan mode has ended. If you have a plan ready, you can now start executing it. If not, you can now use editing tools and make changes to the system.</{VIBE_WARNING_TAG}>"""
|
|
|
|
|
|
class PlanAgentMiddleware:
|
|
def __init__(
|
|
self,
|
|
profile_getter: Callable[[], AgentProfile],
|
|
reminder: str = PLAN_AGENT_REMINDER,
|
|
exit_message: str = PLAN_AGENT_EXIT,
|
|
) -> None:
|
|
self._profile_getter = profile_getter
|
|
self.reminder = reminder
|
|
self.exit_message = exit_message
|
|
self._was_plan_agent = False
|
|
|
|
def _is_plan_agent(self) -> bool:
|
|
return self._profile_getter().name == BuiltinAgentName.PLAN
|
|
|
|
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
is_plan = self._is_plan_agent()
|
|
was_plan = self._was_plan_agent
|
|
|
|
if was_plan and not is_plan:
|
|
self._was_plan_agent = False
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.INJECT_MESSAGE, message=self.exit_message
|
|
)
|
|
|
|
if is_plan and not was_plan:
|
|
self._was_plan_agent = True
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.INJECT_MESSAGE, message=self.reminder
|
|
)
|
|
|
|
self._was_plan_agent = is_plan
|
|
|
|
return MiddlewareResult()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
self._was_plan_agent = False
|
|
|
|
|
|
class MiddlewarePipeline:
|
|
def __init__(self) -> None:
|
|
self.middlewares: list[ConversationMiddleware] = []
|
|
|
|
def add(self, middleware: ConversationMiddleware) -> MiddlewarePipeline:
|
|
self.middlewares.append(middleware)
|
|
return self
|
|
|
|
def clear(self) -> None:
|
|
self.middlewares.clear()
|
|
|
|
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
|
for mw in self.middlewares:
|
|
mw.reset(reset_reason)
|
|
|
|
async def run_before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
|
messages_to_inject = []
|
|
|
|
for mw in self.middlewares:
|
|
result = await mw.before_turn(context)
|
|
if result.action == MiddlewareAction.INJECT_MESSAGE and result.message:
|
|
messages_to_inject.append(result.message)
|
|
elif result.action in {MiddlewareAction.STOP, MiddlewareAction.COMPACT}:
|
|
return result
|
|
if messages_to_inject:
|
|
combined_message = "\n\n".join(messages_to_inject)
|
|
return MiddlewareResult(
|
|
action=MiddlewareAction.INJECT_MESSAGE, message=combined_message
|
|
)
|
|
|
|
return MiddlewareResult()
|