Files
mistral-vibe/tests/tools/test_invoke_context.py
Mathias Gesbert e9a9217cc8 v2.7.4 (#579)
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Lucas Marandat <31749711+lucasmrdt@users.noreply.github.com>
Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com>
Co-authored-by: Paul Cacheux <paul.cacheux@mistral.ai>
Co-authored-by: Peter Evers <pevers90@gmail.com>
Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai>
Co-authored-by: Pierre Rossinès <pierre.rossines@protonmail.com>
Co-authored-by: Quentin <quentin.torroba@mistral.ai>
Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai>
Co-authored-by: Val <102326092+vdeva@users.noreply.github.com>
Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-04-09 18:40:46 +02:00

113 lines
3.8 KiB
Python

from __future__ import annotations
from collections.abc import AsyncGenerator
from pydantic import BaseModel
import pytest
from tests.mock.utils import collect_result
from vibe.core.tools.base import BaseTool, BaseToolConfig, BaseToolState, InvokeContext
from vibe.core.types import ApprovalCallback, ApprovalResponse, ToolStreamEvent
class SimpleArgs(BaseModel):
value: str
class SimpleResult(BaseModel):
output: str
had_context: bool
approval_callback_present: bool
class SimpleTool(BaseTool[SimpleArgs, SimpleResult, BaseToolConfig, BaseToolState]):
description = "A simple test tool"
async def run(
self, args: SimpleArgs, ctx: InvokeContext | None = None
) -> AsyncGenerator[ToolStreamEvent | SimpleResult, None]:
yield SimpleResult(
output=f"processed: {args.value}",
had_context=ctx is not None,
approval_callback_present=ctx is not None
and ctx.approval_callback is not None,
)
@pytest.fixture
def simple_tool() -> SimpleTool:
return SimpleTool(config_getter=lambda: BaseToolConfig(), state=BaseToolState())
class TestInvokeContext:
def test_default_approval_callback_is_none(self) -> None:
ctx = InvokeContext(tool_call_id="test-call-id")
assert ctx.approval_callback is None
def test_approval_callback_can_be_set(self) -> None:
async def dummy_callback(
_tool_name: str,
_args: BaseModel,
_tool_call_id: str,
_rp: list | None = None,
) -> tuple[ApprovalResponse, str | None]:
return ApprovalResponse.YES, None
callback: ApprovalCallback = dummy_callback
ctx = InvokeContext(tool_call_id="test-call-id", approval_callback=callback)
assert ctx.approval_callback is callback
class TestToolInvokeWithContext:
@pytest.mark.asyncio
async def test_invoke_without_context(self, simple_tool: SimpleTool) -> None:
result = await collect_result(simple_tool.invoke(value="test"))
assert result.output == "processed: test"
assert result.had_context is False
assert result.approval_callback_present is False
@pytest.mark.asyncio
async def test_invoke_with_empty_context(self, simple_tool: SimpleTool) -> None:
ctx = InvokeContext(tool_call_id="test-call-id")
result = await collect_result(simple_tool.invoke(ctx=ctx, value="test"))
assert result.output == "processed: test"
assert result.had_context is True
assert result.approval_callback_present is False
@pytest.mark.asyncio
async def test_invoke_with_approval_callback(self, simple_tool: SimpleTool) -> None:
async def dummy_callback(
_tool_name: str,
_args: BaseModel,
_tool_call_id: str,
_rp: list | None = None,
) -> tuple[ApprovalResponse, str | None]:
return ApprovalResponse.YES, None
callback: ApprovalCallback = dummy_callback
ctx = InvokeContext(tool_call_id="test-call-id", approval_callback=callback)
result = await collect_result(simple_tool.invoke(ctx=ctx, value="test"))
assert result.output == "processed: test"
assert result.had_context is True
assert result.approval_callback_present is True
@pytest.mark.asyncio
async def test_run_receives_context(self, simple_tool: SimpleTool) -> None:
ctx = InvokeContext(tool_call_id="test-call-id")
result = await collect_result(
simple_tool.run(SimpleArgs(value="direct"), ctx=ctx)
)
assert result.had_context is True
@pytest.mark.asyncio
async def test_run_without_context_defaults_to_none(
self, simple_tool: SimpleTool
) -> None:
result = await collect_result(simple_tool.run(SimpleArgs(value="direct")))
assert result.had_context is False