Files
mistral-vibe/tests/acp/test_acp.py
Mathias Gesbert eb580209d4 v2.6.0 (#524)
Co-authored-by: Clément Drouin <clement.drouin@mistral.ai>
Co-authored-by: Clément Sirieix <clement.sirieix@mistral.ai>
Co-authored-by: Gauthier Guinet <43207538+Gguinet@users.noreply.github.com>
Co-authored-by: Kim-Adeline Miguel <kimadeline.miguel@mistral.ai>
Co-authored-by: Michel Thomazo <51709227+michelTho@users.noreply.github.com>
Co-authored-by: Quentin <torroba.q@gmail.com>
Co-authored-by: Simon <80467011+sorgfresser@users.noreply.github.com>
Co-authored-by: Simon Van de Kerckhove <simon.vandekerckhove@mistral.ai>
Co-authored-by: Vincent G <10739306+VinceOPS@users.noreply.github.com>
Co-authored-by: angelapopopo <angele.lenglemetz@mistral.ai>
Co-authored-by: Mistral Vibe <vibe@mistral.ai>
2026-03-23 18:45:21 +01:00

1002 lines
35 KiB
Python

from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
import json
import os
from pathlib import Path
from typing import Any
from acp import (
InitializeRequest,
NewSessionRequest,
PromptRequest,
ReadTextFileRequest,
ReadTextFileResponse,
RequestPermissionRequest,
RequestPermissionResponse,
WriteTextFileRequest,
)
from acp.schema import (
AllowedOutcome,
DeniedOutcome,
InitializeResponse,
NewSessionResponse,
PromptResponse,
SessionNotification,
TextContentBlock,
)
from pydantic import BaseModel
import pytest
import tomli_w
from tests import TESTS_ROOT
from tests.conftest import get_base_config
from tests.mock.utils import get_mocking_env, mock_llm_chunk
from vibe.acp.utils import ToolOption
from vibe.core.types import FunctionCall, ToolCall
RESPONSE_TIMEOUT = 2.0
MOCK_ENTRYPOINT_PATH = "tests/mock/mock_entrypoint.py"
PLAYGROUND_DIR = TESTS_ROOT / "playground"
def deep_merge(target: dict, source: dict) -> None:
for key, value in source.items():
if (
key in target
and isinstance(target.get(key), dict)
and isinstance(value, dict)
):
deep_merge(target[key], value)
elif (
key in target
and isinstance(target.get(key), list)
and isinstance(value, list)
):
if key in {"providers", "models"}:
target[key] = value
else:
target[key] = list(set(value + target[key]))
else:
target[key] = value
def _create_vibe_home_dir(tmp_path: Path, *sections: dict[str, Any]) -> Path:
"""Create a temporary vibe home directory with a minimal config file."""
vibe_home = tmp_path / ".vibe"
vibe_home.mkdir()
config_file = vibe_home / "config.toml"
base_config_dict = get_base_config()
base_config_dict["active_model"] = "devstral-latest"
if base_config_dict.get("models"):
for model in base_config_dict["models"]:
if model.get("name") == "mistral-vibe-cli-latest":
model["alias"] = "devstral-latest"
if sections:
for section_dict in sections:
deep_merge(base_config_dict, section_dict)
with config_file.open("wb") as f:
tomli_w.dump(base_config_dict, f)
trusted_folters_file = vibe_home / "trusted_folders.toml"
trusted_folters_file.write_text("trusted = []\nuntrusted = []", encoding="utf-8")
return vibe_home
@pytest.fixture
def vibe_home_dir(tmp_path: Path) -> Path:
"""Create a temporary vibe home directory with a minimal config file."""
return _create_vibe_home_dir(tmp_path)
@pytest.fixture
def vibe_home_grep_ask(tmp_path: Path) -> Path:
"""Create a temporary vibe home directory with grep configured to ask permission."""
return _create_vibe_home_dir(tmp_path, {"tools": {"grep": {"permission": "ask"}}})
class JsonRpcRequest(BaseModel):
jsonrpc: str = "2.0"
id: int | str
method: str
params: Any | None = None
class JsonRpcError(BaseModel):
code: int
message: str
data: Any | None = None
class JsonRpcResponse(BaseModel):
jsonrpc: str = "2.0"
id: int | str | None = None
result: Any | None = None
error: JsonRpcError | None = None
class JsonRpcNotification(BaseModel):
jsonrpc: str = "2.0"
method: str
params: Any | None = None
type JsonRpcMessage = JsonRpcResponse | JsonRpcNotification | JsonRpcRequest
class InitializeJsonRpcRequest(JsonRpcRequest):
method: str = "initialize"
params: InitializeRequest | None = None
class InitializeJsonRpcResponse(JsonRpcResponse):
result: InitializeResponse | None = None
class NewSessionJsonRpcRequest(JsonRpcRequest):
method: str = "session/new"
params: NewSessionRequest | None = None
class NewSessionJsonRpcResponse(JsonRpcResponse):
result: NewSessionResponse | None = None
class PromptJsonRpcRequest(JsonRpcRequest):
method: str = "session/prompt"
params: PromptRequest | None = None
class PromptJsonRpcResponse(JsonRpcResponse):
result: PromptResponse | None = None
class UpdateJsonRpcNotification(JsonRpcNotification):
method: str = "session/update"
params: SessionNotification | None = None
class RequestPermissionJsonRpcRequest(JsonRpcRequest):
method: str = "session/request_permission"
params: RequestPermissionRequest | None = None
class RequestPermissionJsonRpcResponse(JsonRpcResponse):
result: RequestPermissionResponse | None = None
class ReadTextFileJsonRpcRequest(JsonRpcRequest):
method: str = "fs/read_text_file"
params: ReadTextFileRequest | None = None
class ReadTextFileJsonRpcResponse(JsonRpcResponse):
result: ReadTextFileResponse | None = None
class WriteTextFileJsonRpcRequest(JsonRpcRequest):
method: str = "fs/write_text_file"
params: WriteTextFileRequest | None = None
class WriteTextFileJsonRpcResponse(JsonRpcResponse):
result: None = None
async def get_acp_agent_loop_process(
mock_env: dict[str, str], vibe_home: Path
) -> AsyncGenerator[asyncio.subprocess.Process]:
current_env = os.environ.copy()
cmd = ["uv", "run", MOCK_ENTRYPOINT_PATH]
env = dict(current_env)
env.update(mock_env)
env["MISTRAL_API_KEY"] = "mock"
env["VIBE_HOME"] = str(vibe_home)
process = await asyncio.create_subprocess_exec(
*cmd,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=TESTS_ROOT.parent,
env=env,
)
try:
yield process
finally:
# Cleanup
if process.returncode is None:
process.terminate()
try:
await asyncio.wait_for(process.wait(), timeout=0.5)
except TimeoutError:
process.kill()
await process.wait()
async def send_json_rpc(
process: asyncio.subprocess.Process, message: JsonRpcMessage
) -> None:
if process.stdin is None:
raise RuntimeError("Process stdin not available")
request = message.model_dump_json()
request_json = request + "\n"
process.stdin.write(request_json.encode())
await process.stdin.drain()
async def read_response(
process: asyncio.subprocess.Process, timeout: float = RESPONSE_TIMEOUT
) -> str | None:
if process.stdout is None:
raise RuntimeError("Process stdout not available")
try:
# Keep reading lines until we find a valid JSON line
while True:
line = await asyncio.wait_for(process.stdout.readline(), timeout=timeout)
if not line:
return None
line_str = line.decode().strip()
if not line_str:
continue
try:
json.loads(line_str)
return line_str
except json.JSONDecodeError:
# Not JSON, skip it (it's a log message)
continue
except TimeoutError:
return None
async def read_response_for_id(
process: asyncio.subprocess.Process,
expected_id: int | str,
timeout: float = RESPONSE_TIMEOUT,
) -> str | None:
loop = asyncio.get_running_loop()
end_time = loop.time() + timeout
while (remaining := end_time - loop.time()) > 0:
response = await read_response(process, timeout=remaining)
if response is None:
return None
response_json = json.loads(response)
if response_json.get("id") == expected_id:
return response
print(
f"Skipping response with id={response_json.get('id')}, expecting {expected_id}"
)
return None
async def read_multiple_responses(
process: asyncio.subprocess.Process,
max_count: int = 10,
timeout_per_response: float = RESPONSE_TIMEOUT,
) -> list[str]:
responses = []
for _ in range(max_count):
response = await read_response(process, timeout=timeout_per_response)
if response:
responses.append(response)
else:
break
return responses
def parse_conversation(message_texts: list[str]) -> list[JsonRpcMessage]:
parsed_messages: list[JsonRpcMessage] = []
for message_text in message_texts:
message_json = json.loads(message_text)
cls = None
has_method = message_json.get("method", None) is not None
has_id = message_json.get("id", None) is not None
has_result = message_json.get("result", None) is not None
is_request = has_method and has_id
is_notification = has_method and not has_id
is_response = has_result
if is_request:
match message_json.get("method"):
case "session/prompt":
cls = PromptJsonRpcRequest
case "session/request_permission":
cls = RequestPermissionJsonRpcRequest
case "fs/read_text_file":
cls = ReadTextFileJsonRpcRequest
case "fs/write_text_file":
cls = WriteTextFileJsonRpcRequest
elif is_notification:
match message_json.get("method"):
case "session/update":
cls = UpdateJsonRpcNotification
elif is_response:
# For responses, since we don't know the method, we need to find
# the matching request.
matching_request = next(
(
m
for m in parsed_messages
if isinstance(m, JsonRpcRequest) and m.id == message_json.get("id")
),
None,
)
if matching_request is None:
# No matching request found in the conversation, it most probably was
# not included in the conversation. We use a generic response class.
cls = JsonRpcResponse
else:
match matching_request.method:
case "session/prompt":
cls = PromptJsonRpcResponse
case "session/request_permission":
cls = RequestPermissionJsonRpcResponse
case "fs/read_text_file":
cls = ReadTextFileJsonRpcResponse
case "fs/write_text_file":
cls = WriteTextFileJsonRpcResponse
if cls is None:
raise ValueError(f"No valid message class found for {message_json}")
parsed_messages.append(cls.model_validate(message_json))
return parsed_messages
async def initialize_session(acp_agent_loop_process: asyncio.subprocess.Process) -> str:
await send_json_rpc(
acp_agent_loop_process,
InitializeJsonRpcRequest(id=1, params=InitializeRequest(protocol_version=1)),
)
initialize_response = await read_response_for_id(
acp_agent_loop_process, expected_id=1, timeout=5.0
)
assert initialize_response is not None
await send_json_rpc(
acp_agent_loop_process,
NewSessionJsonRpcRequest(
id=2, params=NewSessionRequest(cwd=str(PLAYGROUND_DIR), mcp_servers=[])
),
)
session_response = await read_response_for_id(acp_agent_loop_process, expected_id=2)
assert session_response is not None
session_response_json = json.loads(session_response)
session_response_obj = NewSessionJsonRpcResponse.model_validate(
session_response_json
)
assert session_response_obj.result is not None, "No result in response"
return session_response_obj.result.session_id
class TestSessionManagement:
@pytest.mark.asyncio
async def test_multiple_sessions_unique_ids(self, vibe_home_dir: Path) -> None:
mock_env = get_mocking_env(mock_chunks=[mock_llm_chunk() for _ in range(3)])
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_dir
):
await send_json_rpc(
process,
InitializeJsonRpcRequest(
id=1, params=InitializeRequest(protocol_version=1)
),
)
await read_response_for_id(process, expected_id=1, timeout=5.0)
session_ids = []
for i in range(3):
await send_json_rpc(
process,
NewSessionJsonRpcRequest(
id=i + 2,
params=NewSessionRequest(
cwd=str(PLAYGROUND_DIR), mcp_servers=[]
),
),
)
text_response = await read_response_for_id(
process, expected_id=i + 2, timeout=RESPONSE_TIMEOUT
)
assert text_response is not None
response_json = json.loads(text_response)
response = NewSessionJsonRpcResponse.model_validate(response_json)
assert response.error is None, f"JSON-RPC error: {response.error}"
assert response.result is not None, "No result in response"
session_ids.append(response.result.session_id)
assert len(set(session_ids)) == 3
class TestSessionUpdates:
@pytest.mark.asyncio
async def test_agent_loop_message_chunk_structure(
self, vibe_home_dir: Path
) -> None:
mock_env = get_mocking_env([mock_llm_chunk(content="Hi")])
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_dir
):
# Check stderr for error details if process failed
if process.returncode is not None and process.stderr:
stderr_data = await process.stderr.read()
if stderr_data:
# Log stderr for debugging test failures
pass # Could add proper logging here if needed
session_id = await initialize_session(process)
await send_json_rpc(
process,
PromptJsonRpcRequest(
id=3,
params=PromptRequest(
session_id=session_id,
prompt=[TextContentBlock(type="text", text="Just say hi")],
),
),
)
commands_response_text = await read_response(process)
assert commands_response_text is not None
commands_response = UpdateJsonRpcNotification.model_validate(
json.loads(commands_response_text)
)
assert commands_response.params is not None
assert (
commands_response.params.update.session_update
== "available_commands_update"
)
user_response_text = await read_response(process)
assert user_response_text is not None
user_response = UpdateJsonRpcNotification.model_validate(
json.loads(user_response_text)
)
assert user_response.params is not None
assert user_response.params.update.session_update == "user_message_chunk"
text_response = await read_response(process)
assert text_response is not None
response = UpdateJsonRpcNotification.model_validate(
json.loads(text_response)
)
assert response.params is not None
assert response.params.update.session_update == "agent_message_chunk"
assert response.params.update.content is not None
assert isinstance(response.params.update.content, TextContentBlock)
assert response.params.update.content.type == "text"
assert response.params.update.content.text is not None
assert response.params.update.content.text == "Hi"
@pytest.mark.asyncio
async def test_tool_call_update_structure(self, vibe_home_dir: Path) -> None:
mock_env = get_mocking_env([
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep", arguments='{"pattern": "auth"}'
),
type="function",
index=0,
)
]
),
mock_llm_chunk(content="The files containing the pattern 'auth' are ..."),
])
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_dir
):
session_id = await initialize_session(process)
await send_json_rpc(
process,
PromptJsonRpcRequest(
id=3,
params=PromptRequest(
session_id=session_id,
prompt=[
TextContentBlock(
type="text",
text="Show me files that are related to auth",
)
],
),
),
)
text_responses = await read_multiple_responses(process, max_count=10)
assert len(text_responses) > 0
responses = [
UpdateJsonRpcNotification.model_validate(json.loads(r))
for r in text_responses
]
tool_call = next(
(
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.params is not None
and r.params.update.session_update == "tool_call"
),
None,
)
assert tool_call is not None
assert tool_call.params is not None
assert tool_call.params.update is not None
assert tool_call.params.update.session_update == "tool_call"
assert tool_call.params.update.kind == "search"
assert tool_call.params.update.title == "Grepping 'auth'"
assert (
tool_call.params.update.raw_input
== '{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}'
)
async def start_session_with_request_permission(
process: asyncio.subprocess.Process, prompt: str
) -> RequestPermissionJsonRpcRequest:
session_id = await initialize_session(process)
await send_json_rpc(
process,
PromptJsonRpcRequest(
id=3,
params=PromptRequest(
session_id=session_id,
prompt=[TextContentBlock(type="text", text=prompt)],
),
),
)
text_responses = await read_multiple_responses(
process, max_count=15, timeout_per_response=2.0
)
responses = parse_conversation(text_responses)
last_response = responses[-1]
assert isinstance(last_response, RequestPermissionJsonRpcRequest)
assert last_response.params is not None
assert len(last_response.params.options) == 3
return last_response
class TestToolCallStructure:
@pytest.mark.asyncio
async def test_tool_call_request_permission_structure(
self, vibe_home_grep_ask: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep",
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
),
type="function",
index=0,
)
]
)
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_grep_ask
):
session_id = await initialize_session(process)
await send_json_rpc(
process,
PromptJsonRpcRequest(
id=3,
params=PromptRequest(
session_id=session_id,
prompt=[
TextContentBlock(
type="text",
text="Search for files containing the pattern 'auth'",
)
],
),
),
)
text_responses = await read_multiple_responses(
process, max_count=15, timeout_per_response=2.0
)
responses = parse_conversation(text_responses)
# Look for tool call request permission updates
permission_requests = [
r for r in responses if isinstance(r, RequestPermissionJsonRpcRequest)
]
assert len(permission_requests) > 0, (
f"No tool call permission requests found. Got {len(responses)} responses: "
f"{[type(r).__name__ for r in responses]}"
)
first_request = permission_requests[0]
assert first_request.params is not None
assert first_request.params.tool_call is not None
assert first_request.params.tool_call.tool_call_id is not None
@pytest.mark.asyncio
async def test_tool_call_update_approved_structure(
self, vibe_home_grep_ask: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep",
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
),
type="function",
index=0,
)
]
),
mock_llm_chunk(content="The search for 'auth' has been completed"),
mock_llm_chunk(content="The file test.txt has been created"),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_grep_ask
):
permission_request = await start_session_with_request_permission(
process, "Search for files containing the pattern 'auth'"
)
assert permission_request.params is not None
selected_option_id = ToolOption.ALLOW_ONCE
await send_json_rpc(
process,
RequestPermissionJsonRpcResponse(
id=permission_request.id,
result=RequestPermissionResponse(
outcome=AllowedOutcome(
outcome="selected", option_id=selected_option_id
)
),
),
)
text_responses = await read_multiple_responses(process, max_count=7)
responses = parse_conversation(text_responses)
approved_tool_call = next(
(
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.method == "session/update"
and r.params is not None
and r.params.update is not None
and r.params.update.session_update == "tool_call_update"
and r.params.update.tool_call_id
== (permission_request.params.tool_call.tool_call_id)
and r.params.update.status == "completed"
),
None,
)
assert approved_tool_call is not None
@pytest.mark.asyncio
async def test_tool_call_update_rejected_structure(
self, vibe_home_grep_ask: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep",
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
),
type="function",
index=0,
)
]
),
mock_llm_chunk(
content="The search for 'auth' has not been performed, "
"because you rejected the permission request"
),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_grep_ask
):
permission_request = await start_session_with_request_permission(
process, "Search for files containing the pattern 'auth'"
)
assert permission_request.params is not None
selected_option_id = ToolOption.REJECT_ONCE
await send_json_rpc(
process,
RequestPermissionJsonRpcResponse(
id=permission_request.id,
result=RequestPermissionResponse(
outcome=AllowedOutcome(
outcome="selected", option_id=selected_option_id
)
),
),
)
text_responses = await read_multiple_responses(process, max_count=5)
responses = parse_conversation(text_responses)
rejected_tool_call = next(
(
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.method == "session/update"
and r.params is not None
and r.params.update.session_update == "tool_call_update"
and r.params.update.tool_call_id
== (permission_request.params.tool_call.tool_call_id)
and r.params.update.status == "failed"
),
None,
)
assert rejected_tool_call is not None
@pytest.mark.asyncio
async def test_permission_options_include_granular_labels_for_bash(
self, vibe_home_dir: Path
) -> None:
"""Bash 'npm install foo' should produce granular labels in permission options."""
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="bash", arguments='{"command":"npm install foo"}'
),
type="function",
index=0,
)
]
),
mock_llm_chunk(content="Done"),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_dir
):
permission_request = await start_session_with_request_permission(
process, "Run npm install foo"
)
assert permission_request.params is not None
# Verify "Allow always" option includes the pattern label
allow_always = next(
o
for o in permission_request.params.options
if o.option_id == ToolOption.ALLOW_ALWAYS
)
assert "npm install *" in allow_always.name
@pytest.mark.skip(reason="Long running tool call updates are not implemented yet")
@pytest.mark.asyncio
async def test_tool_call_in_progress_update_structure(
self, vibe_home_grep_ask: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep",
arguments='{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}',
),
type="function",
index=0,
)
]
),
mock_llm_chunk(content="The search for 'auth' has been completed"),
mock_llm_chunk(content="The command sleep 3 has been run"),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_grep_ask
):
session_id = await initialize_session(process)
await send_json_rpc(
process,
PromptJsonRpcRequest(
id=3,
params=PromptRequest(
session_id=session_id,
prompt=[
TextContentBlock(
type="text",
text="Search for files containing the pattern 'auth'",
)
],
),
),
)
text_responses = await read_multiple_responses(process, max_count=4)
responses = parse_conversation(text_responses)
# Look for tool call in progress updates
in_progress_calls = [
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.params is not None
and r.params.update.session_update == "tool_call_update"
and r.params.update.status == "in_progress"
]
assert len(in_progress_calls) > 0, (
"No tool call in progress updates found for a long running command"
)
@pytest.mark.asyncio
async def test_tool_call_result_update_failure_structure(
self, vibe_home_grep_ask: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="grep",
arguments='{"pattern":"auth","path":"/nonexistent","max_matches":null,"use_default_ignore":true}',
),
type="function",
index=0,
)
]
),
mock_llm_chunk(
content="The search for 'auth' has failed "
"because the path does not exist"
),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_grep_ask
):
permission_request = await start_session_with_request_permission(
process,
"Search for files containing the pattern 'auth' in /nonexistent",
)
assert permission_request.params is not None
selected_option_id = ToolOption.ALLOW_ONCE
await send_json_rpc(
process,
RequestPermissionJsonRpcResponse(
id=permission_request.id,
result=RequestPermissionResponse(
outcome=AllowedOutcome(
outcome="selected", option_id=selected_option_id
)
),
),
)
text_responses = await read_multiple_responses(process, max_count=7)
responses = parse_conversation(text_responses)
# Look for tool call result failure updates
failure_result = next(
(
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.params is not None
and r.params.update.session_update == "tool_call_update"
and r.params.update.status == "failed"
and r.params.update.raw_output is not None
and r.params.update.tool_call_id is not None
),
None,
)
assert failure_result is not None
class TestCancellationStructure:
@pytest.mark.skip(
reason="Proper cancellation is not implemented yet, we still need to return "
"the right end_turn and be able to cancel at any point in time "
"(and not only at tool call time)"
)
@pytest.mark.asyncio
async def test_tool_call_update_cancelled_structure(
self, vibe_home_dir: Path
) -> None:
custom_results = [
mock_llm_chunk(
tool_calls=[
ToolCall(
function=FunctionCall(
name="write_file",
arguments='{"path":"test.txt","content":"hello, world!"'
',"overwrite":false}',
),
type="function",
index=0,
)
]
),
mock_llm_chunk(
content="The file test.txt has not been created, "
"because you cancelled the permission request"
),
]
mock_env = get_mocking_env(custom_results)
async for process in get_acp_agent_loop_process(
mock_env=mock_env, vibe_home=vibe_home_dir
):
permission_request = await start_session_with_request_permission(
process, "Create a file named test.txt"
)
assert permission_request.params is not None
await send_json_rpc(
process,
RequestPermissionJsonRpcResponse(
id=permission_request.id,
result=RequestPermissionResponse(
outcome=DeniedOutcome(outcome="cancelled")
),
),
)
text_responses = await read_multiple_responses(process, max_count=5)
responses = parse_conversation(text_responses)
assert len(responses) == 2, (
"There should be only 2 responses: "
"the tool call update and the prompt end turn"
)
cancelled_tool_call = next(
(
r
for r in responses
if isinstance(r, UpdateJsonRpcNotification)
and r.method == "session/update"
and r.params is not None
and r.params.update.session_update == "tool_call_update"
and r.params.update.tool_call_id
== (permission_request.params.tool_call.tool_call_id)
and r.params.update.status == "failed"
),
None,
)
assert cancelled_tool_call is not None
cancelled_prompt_response = next(
(
r
for r in responses
if isinstance(r, PromptJsonRpcResponse)
and r.result is not None
and r.result.stop_reason == "cancelled"
),
None,
)
assert cancelled_prompt_response is not None