mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai> Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai> Co-authored-by: Gauthier Guinet <43207538+Gguinet@users.noreply.github.com> Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai> Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com> Co-authored-by: Quentin <torroba.q@gmail.com> Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com> Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai> Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com> Co-authored-by: angelapopopo <angele.lenglemetz@mistral.ai> Co-authored-by: Mistral Vibe <vibe@mistral.ai>
1002 lines
35 KiB
Python
1002 lines
35 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator
|
|
import json
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from acp import (
|
|
InitializeRequest,
|
|
NewSessionRequest,
|
|
PromptRequest,
|
|
ReadTextFileRequest,
|
|
ReadTextFileResponse,
|
|
RequestPermissionRequest,
|
|
RequestPermissionResponse,
|
|
WriteTextFileRequest,
|
|
)
|
|
from acp.schema import (
|
|
AllowedOutcome,
|
|
DeniedOutcome,
|
|
InitializeResponse,
|
|
NewSessionResponse,
|
|
PromptResponse,
|
|
SessionNotification,
|
|
TextContentBlock,
|
|
)
|
|
from pydantic import BaseModel
|
|
import pytest
|
|
import tomli_w
|
|
|
|
from tests import TESTS_ROOT
|
|
from tests.conftest import get_base_config
|
|
from tests.mock.utils import get_mocking_env, mock_llm_chunk
|
|
from vibe.acp.utils import ToolOption
|
|
from vibe.core.types import FunctionCall, ToolCall
|
|
|
|
RESPONSE_TIMEOUT = 2.0
|
|
MOCK_ENTRYPOINT_PATH = "tests/mock/mock_entrypoint.py"
|
|
PLAYGROUND_DIR = TESTS_ROOT / "playground"
|
|
|
|
|
|
def deep_merge(target: dict, source: dict) -> None:
|
|
for key, value in source.items():
|
|
if (
|
|
key in target
|
|
and isinstance(target.get(key), dict)
|
|
and isinstance(value, dict)
|
|
):
|
|
deep_merge(target[key], value)
|
|
elif (
|
|
key in target
|
|
and isinstance(target.get(key), list)
|
|
and isinstance(value, list)
|
|
):
|
|
if key in {"providers", "models"}:
|
|
target[key] = value
|
|
else:
|
|
target[key] = list(set(value + target[key]))
|
|
else:
|
|
target[key] = value
|
|
|
|
|
|
def _create_vibe_home_dir(tmp_path: Path, *sections: dict[str, Any]) -> Path:
|
|
"""Create a temporary vibe home directory with a minimal config file."""
|
|
vibe_home = tmp_path / ".vibe"
|
|
vibe_home.mkdir()
|
|
|
|
config_file = vibe_home / "config.toml"
|
|
base_config_dict = get_base_config()
|
|
|
|
base_config_dict["active_model"] = "devstral-latest"
|
|
if base_config_dict.get("models"):
|
|
for model in base_config_dict["models"]:
|
|
if model.get("name") == "mistral-vibe-cli-latest":
|
|
model["alias"] = "devstral-latest"
|
|
|
|
if sections:
|
|
for section_dict in sections:
|
|
deep_merge(base_config_dict, section_dict)
|
|
|
|
with config_file.open("wb") as f:
|
|
tomli_w.dump(base_config_dict, f)
|
|
|
|
trusted_folters_file = vibe_home / "trusted_folders.toml"
|
|
trusted_folters_file.write_text("trusted = []\nuntrusted = []", encoding="utf-8")
|
|
|
|
return vibe_home
|
|
|
|
|
|
@pytest.fixture
|
|
def vibe_home_dir(tmp_path: Path) -> Path:
|
|
"""Create a temporary vibe home directory with a minimal config file."""
|
|
return _create_vibe_home_dir(tmp_path)
|
|
|
|
|
|
@pytest.fixture
|
|
def vibe_home_grep_ask(tmp_path: Path) -> Path:
|
|
"""Create a temporary vibe home directory with grep configured to ask permission."""
|
|
return _create_vibe_home_dir(tmp_path, {"tools": {"grep": {"permission": "ask"}}})
|
|
|
|
|
|
class JsonRpcRequest(BaseModel):
|
|
jsonrpc: str = "2.0"
|
|
id: int | str
|
|
method: str
|
|
params: Any | None = None
|
|
|
|
|
|
class JsonRpcError(BaseModel):
|
|
code: int
|
|
message: str
|
|
data: Any | None = None
|
|
|
|
|
|
class JsonRpcResponse(BaseModel):
|
|
jsonrpc: str = "2.0"
|
|
id: int | str | None = None
|
|
result: Any | None = None
|
|
error: JsonRpcError | None = None
|
|
|
|
|
|
class JsonRpcNotification(BaseModel):
|
|
jsonrpc: str = "2.0"
|
|
method: str
|
|
params: Any | None = None
|
|
|
|
|
|
type JsonRpcMessage = JsonRpcResponse | JsonRpcNotification | JsonRpcRequest
|
|
|
|
|
|
class InitializeJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "initialize"
|
|
params: InitializeRequest | None = None
|
|
|
|
|
|
class InitializeJsonRpcResponse(JsonRpcResponse):
|
|
result: InitializeResponse | None = None
|
|
|
|
|
|
class NewSessionJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "session/new"
|
|
params: NewSessionRequest | None = None
|
|
|
|
|
|
class NewSessionJsonRpcResponse(JsonRpcResponse):
|
|
result: NewSessionResponse | None = None
|
|
|
|
|
|
class PromptJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "session/prompt"
|
|
params: PromptRequest | None = None
|
|
|
|
|
|
class PromptJsonRpcResponse(JsonRpcResponse):
|
|
result: PromptResponse | None = None
|
|
|
|
|
|
class UpdateJsonRpcNotification(JsonRpcNotification):
|
|
method: str = "session/update"
|
|
params: SessionNotification | None = None
|
|
|
|
|
|
class RequestPermissionJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "session/request_permission"
|
|
params: RequestPermissionRequest | None = None
|
|
|
|
|
|
class RequestPermissionJsonRpcResponse(JsonRpcResponse):
|
|
result: RequestPermissionResponse | None = None
|
|
|
|
|
|
class ReadTextFileJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "fs/read_text_file"
|
|
params: ReadTextFileRequest | None = None
|
|
|
|
|
|
class ReadTextFileJsonRpcResponse(JsonRpcResponse):
|
|
result: ReadTextFileResponse | None = None
|
|
|
|
|
|
class WriteTextFileJsonRpcRequest(JsonRpcRequest):
|
|
method: str = "fs/write_text_file"
|
|
params: WriteTextFileRequest | None = None
|
|
|
|
|
|
class WriteTextFileJsonRpcResponse(JsonRpcResponse):
|
|
result: None = None
|
|
|
|
|
|
async def get_acp_agent_loop_process(
|
|
mock_env: dict[str, str], vibe_home: Path
|
|
) -> AsyncGenerator[asyncio.subprocess.Process]:
|
|
current_env = os.environ.copy()
|
|
cmd = ["uv", "run", MOCK_ENTRYPOINT_PATH]
|
|
|
|
env = dict(current_env)
|
|
env.update(mock_env)
|
|
env["MISTRAL_API_KEY"] = "mock"
|
|
env["VIBE_HOME"] = str(vibe_home)
|
|
|
|
process = await asyncio.create_subprocess_exec(
|
|
*cmd,
|
|
stdin=asyncio.subprocess.PIPE,
|
|
stdout=asyncio.subprocess.PIPE,
|
|
stderr=asyncio.subprocess.PIPE,
|
|
cwd=TESTS_ROOT.parent,
|
|
env=env,
|
|
)
|
|
|
|
try:
|
|
yield process
|
|
finally:
|
|
# Cleanup
|
|
if process.returncode is None:
|
|
process.terminate()
|
|
try:
|
|
await asyncio.wait_for(process.wait(), timeout=0.5)
|
|
except TimeoutError:
|
|
process.kill()
|
|
await process.wait()
|
|
|
|
|
|
async def send_json_rpc(
|
|
process: asyncio.subprocess.Process, message: JsonRpcMessage
|
|
) -> None:
|
|
if process.stdin is None:
|
|
raise RuntimeError("Process stdin not available")
|
|
|
|
request = message.model_dump_json()
|
|
request_json = request + "\n"
|
|
process.stdin.write(request_json.encode())
|
|
await process.stdin.drain()
|
|
|
|
|
|
async def read_response(
|
|
process: asyncio.subprocess.Process, timeout: float = RESPONSE_TIMEOUT
|
|
) -> str | None:
|
|
if process.stdout is None:
|
|
raise RuntimeError("Process stdout not available")
|
|
|
|
try:
|
|
# Keep reading lines until we find a valid JSON line
|
|
while True:
|
|
line = await asyncio.wait_for(process.stdout.readline(), timeout=timeout)
|
|
|
|
if not line:
|
|
return None
|
|
|
|
line_str = line.decode().strip()
|
|
if not line_str:
|
|
continue
|
|
|
|
try:
|
|
json.loads(line_str)
|
|
return line_str
|
|
except json.JSONDecodeError:
|
|
# Not JSON, skip it (it's a log message)
|
|
continue
|
|
except TimeoutError:
|
|
return None
|
|
|
|
|
|
async def read_response_for_id(
|
|
process: asyncio.subprocess.Process,
|
|
expected_id: int | str,
|
|
timeout: float = RESPONSE_TIMEOUT,
|
|
) -> str | None:
|
|
loop = asyncio.get_running_loop()
|
|
end_time = loop.time() + timeout
|
|
|
|
while (remaining := end_time - loop.time()) > 0:
|
|
response = await read_response(process, timeout=remaining)
|
|
if response is None:
|
|
return None
|
|
|
|
response_json = json.loads(response)
|
|
if response_json.get("id") == expected_id:
|
|
return response
|
|
print(
|
|
f"Skipping response with id={response_json.get('id')}, expecting {expected_id}"
|
|
)
|
|
|
|
return None
|
|
|
|
|
|
async def read_multiple_responses(
|
|
process: asyncio.subprocess.Process,
|
|
max_count: int = 10,
|
|
timeout_per_response: float = RESPONSE_TIMEOUT,
|
|
) -> list[str]:
|
|
responses = []
|
|
for _ in range(max_count):
|
|
response = await read_response(process, timeout=timeout_per_response)
|
|
if response:
|
|
responses.append(response)
|
|
else:
|
|
break
|
|
return responses
|
|
|
|
|
|
def parse_conversation(message_texts: list[str]) -> list[JsonRpcMessage]:
|
|
parsed_messages: list[JsonRpcMessage] = []
|
|
for message_text in message_texts:
|
|
message_json = json.loads(message_text)
|
|
cls = None
|
|
has_method = message_json.get("method", None) is not None
|
|
has_id = message_json.get("id", None) is not None
|
|
has_result = message_json.get("result", None) is not None
|
|
|
|
is_request = has_method and has_id
|
|
is_notification = has_method and not has_id
|
|
is_response = has_result
|
|
|
|
if is_request:
|
|
match message_json.get("method"):
|
|
case "session/prompt":
|
|
cls = PromptJsonRpcRequest
|
|
case "session/request_permission":
|
|
cls = RequestPermissionJsonRpcRequest
|
|
case "fs/read_text_file":
|
|
cls = ReadTextFileJsonRpcRequest
|
|
case "fs/write_text_file":
|
|
cls = WriteTextFileJsonRpcRequest
|
|
elif is_notification:
|
|
match message_json.get("method"):
|
|
case "session/update":
|
|
cls = UpdateJsonRpcNotification
|
|
elif is_response:
|
|
# For responses, since we don't know the method, we need to find
|
|
# the matching request.
|
|
matching_request = next(
|
|
(
|
|
m
|
|
for m in parsed_messages
|
|
if isinstance(m, JsonRpcRequest) and m.id == message_json.get("id")
|
|
),
|
|
None,
|
|
)
|
|
if matching_request is None:
|
|
# No matching request found in the conversation, it most probably was
|
|
# not included in the conversation. We use a generic response class.
|
|
cls = JsonRpcResponse
|
|
else:
|
|
match matching_request.method:
|
|
case "session/prompt":
|
|
cls = PromptJsonRpcResponse
|
|
case "session/request_permission":
|
|
cls = RequestPermissionJsonRpcResponse
|
|
case "fs/read_text_file":
|
|
cls = ReadTextFileJsonRpcResponse
|
|
case "fs/write_text_file":
|
|
cls = WriteTextFileJsonRpcResponse
|
|
if cls is None:
|
|
raise ValueError(f"No valid message class found for {message_json}")
|
|
parsed_messages.append(cls.model_validate(message_json))
|
|
return parsed_messages
|
|
|
|
|
|
async def initialize_session(acp_agent_loop_process: asyncio.subprocess.Process) -> str:
|
|
await send_json_rpc(
|
|
acp_agent_loop_process,
|
|
InitializeJsonRpcRequest(id=1, params=InitializeRequest(protocol_version=1)),
|
|
)
|
|
initialize_response = await read_response_for_id(
|
|
acp_agent_loop_process, expected_id=1, timeout=5.0
|
|
)
|
|
assert initialize_response is not None
|
|
|
|
await send_json_rpc(
|
|
acp_agent_loop_process,
|
|
NewSessionJsonRpcRequest(
|
|
id=2, params=NewSessionRequest(cwd=str(PLAYGROUND_DIR), mcp_servers=[])
|
|
),
|
|
)
|
|
session_response = await read_response_for_id(acp_agent_loop_process, expected_id=2)
|
|
assert session_response is not None
|
|
session_response_json = json.loads(session_response)
|
|
session_response_obj = NewSessionJsonRpcResponse.model_validate(
|
|
session_response_json
|
|
)
|
|
assert session_response_obj.result is not None, "No result in response"
|
|
return session_response_obj.result.session_id
|
|
|
|
|
|
class TestSessionManagement:
|
|
@pytest.mark.asyncio
|
|
async def test_multiple_sessions_unique_ids(self, vibe_home_dir: Path) -> None:
|
|
mock_env = get_mocking_env(mock_chunks=[mock_llm_chunk() for _ in range(3)])
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_dir
|
|
):
|
|
await send_json_rpc(
|
|
process,
|
|
InitializeJsonRpcRequest(
|
|
id=1, params=InitializeRequest(protocol_version=1)
|
|
),
|
|
)
|
|
await read_response_for_id(process, expected_id=1, timeout=5.0)
|
|
|
|
session_ids = []
|
|
for i in range(3):
|
|
await send_json_rpc(
|
|
process,
|
|
NewSessionJsonRpcRequest(
|
|
id=i + 2,
|
|
params=NewSessionRequest(
|
|
cwd=str(PLAYGROUND_DIR), mcp_servers=[]
|
|
),
|
|
),
|
|
)
|
|
text_response = await read_response_for_id(
|
|
process, expected_id=i + 2, timeout=RESPONSE_TIMEOUT
|
|
)
|
|
assert text_response is not None
|
|
response_json = json.loads(text_response)
|
|
response = NewSessionJsonRpcResponse.model_validate(response_json)
|
|
assert response.error is None, f"JSON-RPC error: {response.error}"
|
|
assert response.result is not None, "No result in response"
|
|
session_ids.append(response.result.session_id)
|
|
|
|
assert len(set(session_ids)) == 3
|
|
|
|
|
|
class TestSessionUpdates:
|
|
@pytest.mark.asyncio
|
|
async def test_agent_loop_message_chunk_structure(
|
|
self, vibe_home_dir: Path
|
|
) -> None:
|
|
mock_env = get_mocking_env([mock_llm_chunk(content="Hi")])
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_dir
|
|
):
|
|
# Check stderr for error details if process failed
|
|
if process.returncode is not None and process.stderr:
|
|
stderr_data = await process.stderr.read()
|
|
if stderr_data:
|
|
# Log stderr for debugging test failures
|
|
pass # Could add proper logging here if needed
|
|
|
|
session_id = await initialize_session(process)
|
|
|
|
await send_json_rpc(
|
|
process,
|
|
PromptJsonRpcRequest(
|
|
id=3,
|
|
params=PromptRequest(
|
|
session_id=session_id,
|
|
prompt=[TextContentBlock(type="text", text="Just say hi")],
|
|
),
|
|
),
|
|
)
|
|
|
|
commands_response_text = await read_response(process)
|
|
assert commands_response_text is not None
|
|
commands_response = UpdateJsonRpcNotification.model_validate(
|
|
json.loads(commands_response_text)
|
|
)
|
|
assert commands_response.params is not None
|
|
assert (
|
|
commands_response.params.update.session_update
|
|
== "available_commands_update"
|
|
)
|
|
|
|
user_response_text = await read_response(process)
|
|
assert user_response_text is not None
|
|
user_response = UpdateJsonRpcNotification.model_validate(
|
|
json.loads(user_response_text)
|
|
)
|
|
assert user_response.params is not None
|
|
assert user_response.params.update.session_update == "user_message_chunk"
|
|
|
|
text_response = await read_response(process)
|
|
assert text_response is not None
|
|
response = UpdateJsonRpcNotification.model_validate(
|
|
json.loads(text_response)
|
|
)
|
|
|
|
assert response.params is not None
|
|
assert response.params.update.session_update == "agent_message_chunk"
|
|
assert response.params.update.content is not None
|
|
assert isinstance(response.params.update.content, TextContentBlock)
|
|
assert response.params.update.content.type == "text"
|
|
assert response.params.update.content.text is not None
|
|
assert response.params.update.content.text == "Hi"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_update_structure(self, vibe_home_dir: Path) -> None:
|
|
mock_env = get_mocking_env([
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep", arguments='{"pattern": "auth"}'
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(content="The files containing the pattern 'auth' are ..."),
|
|
])
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_dir
|
|
):
|
|
session_id = await initialize_session(process)
|
|
|
|
await send_json_rpc(
|
|
process,
|
|
PromptJsonRpcRequest(
|
|
id=3,
|
|
params=PromptRequest(
|
|
session_id=session_id,
|
|
prompt=[
|
|
TextContentBlock(
|
|
type="text",
|
|
text="Show me files that are related to auth",
|
|
)
|
|
],
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=10)
|
|
assert len(text_responses) > 0
|
|
responses = [
|
|
UpdateJsonRpcNotification.model_validate(json.loads(r))
|
|
for r in text_responses
|
|
]
|
|
|
|
tool_call = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.params is not None
|
|
and r.params.update.session_update == "tool_call"
|
|
),
|
|
None,
|
|
)
|
|
assert tool_call is not None
|
|
assert tool_call.params is not None
|
|
assert tool_call.params.update is not None
|
|
|
|
assert tool_call.params.update.session_update == "tool_call"
|
|
assert tool_call.params.update.kind == "search"
|
|
assert tool_call.params.update.title == "Grepping 'auth'"
|
|
assert (
|
|
tool_call.params.update.raw_input
|
|
== '{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}'
|
|
)
|
|
|
|
|
|
async def start_session_with_request_permission(
|
|
process: asyncio.subprocess.Process, prompt: str
|
|
) -> RequestPermissionJsonRpcRequest:
|
|
session_id = await initialize_session(process)
|
|
await send_json_rpc(
|
|
process,
|
|
PromptJsonRpcRequest(
|
|
id=3,
|
|
params=PromptRequest(
|
|
session_id=session_id,
|
|
prompt=[TextContentBlock(type="text", text=prompt)],
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(
|
|
process, max_count=15, timeout_per_response=2.0
|
|
)
|
|
|
|
responses = parse_conversation(text_responses)
|
|
last_response = responses[-1]
|
|
|
|
assert isinstance(last_response, RequestPermissionJsonRpcRequest)
|
|
assert last_response.params is not None
|
|
assert len(last_response.params.options) == 3
|
|
return last_response
|
|
|
|
|
|
class TestToolCallStructure:
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_request_permission_structure(
|
|
self, vibe_home_grep_ask: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep",
|
|
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
)
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_grep_ask
|
|
):
|
|
session_id = await initialize_session(process)
|
|
await send_json_rpc(
|
|
process,
|
|
PromptJsonRpcRequest(
|
|
id=3,
|
|
params=PromptRequest(
|
|
session_id=session_id,
|
|
prompt=[
|
|
TextContentBlock(
|
|
type="text",
|
|
text="Search for files containing the pattern 'auth'",
|
|
)
|
|
],
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(
|
|
process, max_count=15, timeout_per_response=2.0
|
|
)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
# Look for tool call request permission updates
|
|
permission_requests = [
|
|
r for r in responses if isinstance(r, RequestPermissionJsonRpcRequest)
|
|
]
|
|
|
|
assert len(permission_requests) > 0, (
|
|
f"No tool call permission requests found. Got {len(responses)} responses: "
|
|
f"{[type(r).__name__ for r in responses]}"
|
|
)
|
|
|
|
first_request = permission_requests[0]
|
|
assert first_request.params is not None
|
|
assert first_request.params.tool_call is not None
|
|
assert first_request.params.tool_call.tool_call_id is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_update_approved_structure(
|
|
self, vibe_home_grep_ask: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep",
|
|
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(content="The search for 'auth' has been completed"),
|
|
mock_llm_chunk(content="The file test.txt has been created"),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_grep_ask
|
|
):
|
|
permission_request = await start_session_with_request_permission(
|
|
process, "Search for files containing the pattern 'auth'"
|
|
)
|
|
assert permission_request.params is not None
|
|
selected_option_id = ToolOption.ALLOW_ONCE
|
|
await send_json_rpc(
|
|
process,
|
|
RequestPermissionJsonRpcResponse(
|
|
id=permission_request.id,
|
|
result=RequestPermissionResponse(
|
|
outcome=AllowedOutcome(
|
|
outcome="selected", option_id=selected_option_id
|
|
)
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=7)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
approved_tool_call = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.method == "session/update"
|
|
and r.params is not None
|
|
and r.params.update is not None
|
|
and r.params.update.session_update == "tool_call_update"
|
|
and r.params.update.tool_call_id
|
|
== (permission_request.params.tool_call.tool_call_id)
|
|
and r.params.update.status == "completed"
|
|
),
|
|
None,
|
|
)
|
|
assert approved_tool_call is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_update_rejected_structure(
|
|
self, vibe_home_grep_ask: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep",
|
|
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(
|
|
content="The search for 'auth' has not been performed, "
|
|
"because you rejected the permission request"
|
|
),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_grep_ask
|
|
):
|
|
permission_request = await start_session_with_request_permission(
|
|
process, "Search for files containing the pattern 'auth'"
|
|
)
|
|
assert permission_request.params is not None
|
|
selected_option_id = ToolOption.REJECT_ONCE
|
|
|
|
await send_json_rpc(
|
|
process,
|
|
RequestPermissionJsonRpcResponse(
|
|
id=permission_request.id,
|
|
result=RequestPermissionResponse(
|
|
outcome=AllowedOutcome(
|
|
outcome="selected", option_id=selected_option_id
|
|
)
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=5)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
rejected_tool_call = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.method == "session/update"
|
|
and r.params is not None
|
|
and r.params.update.session_update == "tool_call_update"
|
|
and r.params.update.tool_call_id
|
|
== (permission_request.params.tool_call.tool_call_id)
|
|
and r.params.update.status == "failed"
|
|
),
|
|
None,
|
|
)
|
|
assert rejected_tool_call is not None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_permission_options_include_granular_labels_for_bash(
|
|
self, vibe_home_dir: Path
|
|
) -> None:
|
|
"""Bash 'npm install foo' should produce granular labels in permission options."""
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="bash", arguments='{"command":"npm install foo"}'
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(content="Done"),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_dir
|
|
):
|
|
permission_request = await start_session_with_request_permission(
|
|
process, "Run npm install foo"
|
|
)
|
|
assert permission_request.params is not None
|
|
|
|
# Verify "Allow always" option includes the pattern label
|
|
allow_always = next(
|
|
o
|
|
for o in permission_request.params.options
|
|
if o.option_id == ToolOption.ALLOW_ALWAYS
|
|
)
|
|
assert "npm install *" in allow_always.name
|
|
|
|
@pytest.mark.skip(reason="Long running tool call updates are not implemented yet")
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_in_progress_update_structure(
|
|
self, vibe_home_grep_ask: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep",
|
|
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(content="The search for 'auth' has been completed"),
|
|
mock_llm_chunk(content="The command sleep 3 has been run"),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_grep_ask
|
|
):
|
|
session_id = await initialize_session(process)
|
|
await send_json_rpc(
|
|
process,
|
|
PromptJsonRpcRequest(
|
|
id=3,
|
|
params=PromptRequest(
|
|
session_id=session_id,
|
|
prompt=[
|
|
TextContentBlock(
|
|
type="text",
|
|
text="Search for files containing the pattern 'auth'",
|
|
)
|
|
],
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=4)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
# Look for tool call in progress updates
|
|
in_progress_calls = [
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.params is not None
|
|
and r.params.update.session_update == "tool_call_update"
|
|
and r.params.update.status == "in_progress"
|
|
]
|
|
|
|
assert len(in_progress_calls) > 0, (
|
|
"No tool call in progress updates found for a long running command"
|
|
)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_result_update_failure_structure(
|
|
self, vibe_home_grep_ask: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="grep",
|
|
arguments='{"pattern":"auth","path":"/nonexistent","max_matches":null,"use_default_ignore":true}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(
|
|
content="The search for 'auth' has failed "
|
|
"because the path does not exist"
|
|
),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_grep_ask
|
|
):
|
|
permission_request = await start_session_with_request_permission(
|
|
process,
|
|
"Search for files containing the pattern 'auth' in /nonexistent",
|
|
)
|
|
assert permission_request.params is not None
|
|
selected_option_id = ToolOption.ALLOW_ONCE
|
|
await send_json_rpc(
|
|
process,
|
|
RequestPermissionJsonRpcResponse(
|
|
id=permission_request.id,
|
|
result=RequestPermissionResponse(
|
|
outcome=AllowedOutcome(
|
|
outcome="selected", option_id=selected_option_id
|
|
)
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=7)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
# Look for tool call result failure updates
|
|
failure_result = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.params is not None
|
|
and r.params.update.session_update == "tool_call_update"
|
|
and r.params.update.status == "failed"
|
|
and r.params.update.raw_output is not None
|
|
and r.params.update.tool_call_id is not None
|
|
),
|
|
None,
|
|
)
|
|
|
|
assert failure_result is not None
|
|
|
|
|
|
class TestCancellationStructure:
|
|
@pytest.mark.skip(
|
|
reason="Proper cancellation is not implemented yet, we still need to return "
|
|
"the right end_turn and be able to cancel at any point in time "
|
|
"(and not only at tool call time)"
|
|
)
|
|
@pytest.mark.asyncio
|
|
async def test_tool_call_update_cancelled_structure(
|
|
self, vibe_home_dir: Path
|
|
) -> None:
|
|
custom_results = [
|
|
mock_llm_chunk(
|
|
tool_calls=[
|
|
ToolCall(
|
|
function=FunctionCall(
|
|
name="write_file",
|
|
arguments='{"path":"test.txt","content":"hello, world!"'
|
|
',"overwrite":false}',
|
|
),
|
|
type="function",
|
|
index=0,
|
|
)
|
|
]
|
|
),
|
|
mock_llm_chunk(
|
|
content="The file test.txt has not been created, "
|
|
"because you cancelled the permission request"
|
|
),
|
|
]
|
|
mock_env = get_mocking_env(custom_results)
|
|
async for process in get_acp_agent_loop_process(
|
|
mock_env=mock_env, vibe_home=vibe_home_dir
|
|
):
|
|
permission_request = await start_session_with_request_permission(
|
|
process, "Create a file named test.txt"
|
|
)
|
|
assert permission_request.params is not None
|
|
|
|
await send_json_rpc(
|
|
process,
|
|
RequestPermissionJsonRpcResponse(
|
|
id=permission_request.id,
|
|
result=RequestPermissionResponse(
|
|
outcome=DeniedOutcome(outcome="cancelled")
|
|
),
|
|
),
|
|
)
|
|
text_responses = await read_multiple_responses(process, max_count=5)
|
|
responses = parse_conversation(text_responses)
|
|
|
|
assert len(responses) == 2, (
|
|
"There should be only 2 responses: "
|
|
"the tool call update and the prompt end turn"
|
|
)
|
|
|
|
cancelled_tool_call = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, UpdateJsonRpcNotification)
|
|
and r.method == "session/update"
|
|
and r.params is not None
|
|
and r.params.update.session_update == "tool_call_update"
|
|
and r.params.update.tool_call_id
|
|
== (permission_request.params.tool_call.tool_call_id)
|
|
and r.params.update.status == "failed"
|
|
),
|
|
None,
|
|
)
|
|
assert cancelled_tool_call is not None
|
|
|
|
cancelled_prompt_response = next(
|
|
(
|
|
r
|
|
for r in responses
|
|
if isinstance(r, PromptJsonRpcResponse)
|
|
and r.result is not None
|
|
and r.result.stop_reason == "cancelled"
|
|
),
|
|
None,
|
|
)
|
|
assert cancelled_prompt_response is not None
|