Files
mistral-vibe/tests/tools/test_ask_user_question.py
Mathias Gesbert e9a9217cc8 v2.7.4 (#579)
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Lucas Marandat <31749711+lucasmrdt@users.noreply.github.com>
Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com>
Co-authored-by: Paul Cacheux <paul.cacheux@mistral.ai>
Co-authored-by: Peter Evers <pevers90@gmail.com>
Co-authored-by: Pierre Rossinès <pierre.rossines@mistral.ai>
Co-authored-by: Pierre Rossinès <pierre.rossines@protonmail.com>
Co-authored-by: Quentin <quentin.torroba@mistral.ai>
Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai>
Co-authored-by: Val <102326092+vdeva@users.noreply.github.com>
Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-04-09 18:40:46 +02:00

189 lines
5.7 KiB
Python

from __future__ import annotations
import pytest
from vibe.core.tools.base import BaseToolState, ToolError
from vibe.core.tools.builtins.ask_user_question import (
Answer,
AskUserQuestion,
AskUserQuestionArgs,
AskUserQuestionConfig,
AskUserQuestionResult,
Choice,
Question,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
@pytest.fixture
def tool():
config = AskUserQuestionConfig()
return AskUserQuestion(config_getter=lambda: config, state=BaseToolState())
@pytest.fixture
def single_question_args():
return AskUserQuestionArgs(
questions=[
Question(
question="Which database?",
header="DB",
options=[
Choice(label="PostgreSQL", description="Relational DB"),
Choice(label="MongoDB", description="Document DB"),
],
)
]
)
@pytest.fixture
def multi_question_args():
return AskUserQuestionArgs(
questions=[
Question(
question="Which database?",
header="DB",
options=[Choice(label="PostgreSQL"), Choice(label="MongoDB")],
),
Question(
question="Which framework?",
header="Framework",
options=[Choice(label="FastAPI"), Choice(label="Django")],
),
]
)
async def run_tool_with_callback(tool, args, callback):
from vibe.core.tools.base import InvokeContext
ctx = InvokeContext(user_input_callback=callback, tool_call_id="123")
result = None
async for item in tool.run(args, ctx):
result = item
return result
@pytest.mark.asyncio
async def test_raises_error_without_callback(tool, single_question_args):
with pytest.raises(ToolError) as err:
async for _ in tool.run(single_question_args, ctx=None):
pass
assert "interactive UI" in str(err.value)
@pytest.mark.asyncio
async def test_calls_callback_and_returns_result(tool, single_question_args):
expected_result = AskUserQuestionResult(
answers=[
Answer(question="Which database?", answer="PostgreSQL", is_other=False)
],
cancelled=False,
)
async def mock_callback(args):
assert args == single_question_args
return expected_result
result = await run_tool_with_callback(tool, single_question_args, mock_callback)
assert result is not None
assert result == expected_result
assert result.answers[0].answer == "PostgreSQL"
@pytest.mark.asyncio
async def test_handles_cancelled_result(tool, single_question_args):
expected_result = AskUserQuestionResult(answers=[], cancelled=True)
async def mock_callback(args):
return expected_result
result = await run_tool_with_callback(tool, single_question_args, mock_callback)
assert result is not None
assert result.cancelled is True
assert len(result.answers) == 0
@pytest.mark.asyncio
async def test_handles_other_response(tool, single_question_args):
expected_result = AskUserQuestionResult(
answers=[Answer(question="Which database?", answer="SQLite", is_other=True)],
cancelled=False,
)
async def mock_callback(args):
return expected_result
result = await run_tool_with_callback(tool, single_question_args, mock_callback)
assert result is not None
assert result.answers[0].is_other is True
assert result.answers[0].answer == "SQLite"
class TestToolUIDisplay:
def test_get_call_display_single_question(self, single_question_args):
event = ToolCallEvent(
tool_name="ask_user_question",
tool_class=AskUserQuestion,
args=single_question_args,
tool_call_id="123",
)
display = AskUserQuestion.get_call_display(event)
assert "Which database?" in display.summary
def test_get_call_display_multiple_questions(self, multi_question_args):
event = ToolCallEvent(
tool_name="ask_user_question",
tool_class=AskUserQuestion,
args=multi_question_args,
tool_call_id="123",
)
display = AskUserQuestion.get_call_display(event)
assert "2 questions" in display.summary
def test_get_result_display_success(self):
result = AskUserQuestionResult(
answers=[Answer(question="Q?", answer="A", is_other=False)], cancelled=False
)
event = ToolResultEvent(
tool_name="ask_user_question",
tool_class=AskUserQuestion,
result=result,
tool_call_id="123",
)
display = AskUserQuestion.get_result_display(event)
assert display.success is True
assert "A" in display.message
def test_get_result_display_cancelled(self):
result = AskUserQuestionResult(answers=[], cancelled=True)
event = ToolResultEvent(
tool_name="ask_user_question",
tool_class=AskUserQuestion,
result=result,
tool_call_id="123",
)
display = AskUserQuestion.get_result_display(event)
assert display.success is False
assert "cancelled" in display.message.lower()
def test_get_result_display_other(self):
result = AskUserQuestionResult(
answers=[Answer(question="Q?", answer="Custom", is_other=True)],
cancelled=False,
)
event = ToolResultEvent(
tool_name="ask_user_question",
tool_class=AskUserQuestion,
result=result,
tool_call_id="123",
)
display = AskUserQuestion.get_result_display(event)
assert display.success is True
assert "(Other)" in display.message