Initial commit

Co-Authored-By: Quentin Torroba <quentin.torroba@mistral.ai>
Co-Authored-By: Laure Hugo <laure.hugo@mistral.ai>
Co-Authored-By: Benjamin Trom <benjamin.trom@mistral.ai>
Co-Authored-By: Mathias Gesbert <mathias.gesbert@ext.mistral.ai>
Co-Authored-By: Michel Thomazo <michel.thomazo@mistral.ai>
Co-Authored-By: Clément Drouin <clement.drouin@mistral.ai>
Co-Authored-By: Vincent Guilloux <vincent.guilloux@mistral.ai>
Co-Authored-By: Valentin Berard <val@mistral.ai>
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
Quentin Torroba
2025-12-09 13:13:22 +01:00
committed by Quentin Torroba
commit fa15fc977b
200 changed files with 30484 additions and 0 deletions

12
.envrc Normal file
View File

@@ -0,0 +1,12 @@
# shellcheck shell=bash
if ! has nix_direnv_version || ! nix_direnv_version 3.1.0; then
source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/3.1.0/direnvrc" "sha256-yMJ2OVMzrFaDPn7q8nCBZFRYpL/f0RcHzhmw/i6btJM="
fi
if command -v nix &>/dev/null; then
use flake
fi
if command -v pre-commit &>/dev/null; then
pre-commit install
fi

7
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1,7 @@
# CODEOWNERS
# Default owners for everything in the repo
* @mistralai/mistral-vibe
# Owners for specific directories
# Not needed for now, can be filled later

86
.github/workflows/build-and-upload.yml vendored Normal file
View File

@@ -0,0 +1,86 @@
name: Build and upload
on:
workflow_dispatch:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
build-and-upload:
name: ${{ matrix.os }}-${{ matrix.arch }}
strategy:
matrix:
include:
# Linux
- runner: ubuntu-22.04
os: linux
arch: x86_64
# - runner: ubuntu-22.04-arm # ubuntu-22.04-arm, ubuntu-24.04-arm and windows-11-arm are not supported yet for private repositories
# os: linux
# arch: aarch64
# macOS
- runner: macos-15-intel
os: darwin
arch: x86_64
- runner: macos-14
os: darwin
arch: aarch64
# Windows
- runner: windows-2022
os: windows
arch: x86_64
# - runner: windows-11-arm # ubuntu-22.04-arm, ubuntu-24.04-arm and windows-11-arm are not supported yet for private repositories
# os: windows
# arch: aarch64
runs-on: ${{ matrix.runner }}
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
- name: Install uv with caching
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: "3.12"
- name: Sync dependencies
run: uv sync --all-extras
- name: Build with PyInstaller
run: uv run pyinstaller vibe-acp.spec
- name: Get package version with uv (Unix)
id: get_version_unix
if: ${{ matrix.os != 'windows' }}
run: python -c "import subprocess; version = subprocess.check_output(['uv', 'version']).decode().split()[1]; print(f'version={version}')" >> $GITHUB_OUTPUT
- name: Get package version with uv (Windows)
id: get_version_windows
if: ${{ matrix.os == 'windows' }}
shell: pwsh
run: python -c "import subprocess; version = subprocess.check_output(['uv', 'version']).decode().split()[1]; print(f'version={version}')" >> $env:GITHUB_OUTPUT
- name: Upload binary as artifact (Unix)
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
if: ${{ matrix.os != 'windows' }}
with:
name: vibe-acp-${{ matrix.os }}-${{ matrix.arch }}-${{ steps.get_version_unix.outputs.version }}
path: dist/vibe-acp
- name: Upload binary as artifact (Windows)
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
if: ${{ matrix.os == 'windows' }}
with:
name: vibe-acp-${{ matrix.os }}-${{ matrix.arch }}-${{ steps.get_version_windows.outputs.version }}
path: dist/vibe-acp.exe

126
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,126 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
env:
PYTHON_VERSION: "3.12"
jobs:
pre-commit:
name: Pre-commit
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
- name: Install uv with caching
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Sync dependencies
run: uv sync --all-extras
- name: Install pip (required by pre-commit)
run: uv pip install pip
- name: Add virtual environment to PATH
run: echo "$(pwd)/.venv/bin" >> $GITHUB_PATH
- name: Cache pre-commit
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
with:
path: ~/.cache/pre-commit
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
- name: Run pre-commit
run: uv run pre-commit run --all-files --show-diff-on-failure
tests:
name: Tests
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
- name: Install uv with caching
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Sync dependencies
run: uv sync --all-extras
- name: Verify CLI can start
run: |
uv run vibe --help
uv run vibe-acp --help
- name: Install ripgrep
run: sudo apt-get update && sudo apt-get install -y ripgrep
- name: Run tests
run: uv run pytest --ignore tests/snapshots
snapshot-tests:
name: Snapshot Tests
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
- name: Install uv with caching
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
with:
version: "latest"
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Sync dependencies
run: uv sync --all-extras
- name: Run snapshot tests
id: snapshot-tests
run: uv run pytest tests/snapshots
continue-on-error: true
- name: Upload snapshot report
if: steps.snapshot-tests.outcome == 'failure'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
with:
name: snapshot-report
path: snapshot_report.html
if-no-files-found: warn
retention-days: 3
- name: Fail job if snapshot tests failed
if: steps.snapshot-tests.outcome == 'failure'
run: |
echo "Snapshot tests failed, failing job."
exit 1

44
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,44 @@
name: Release to PyPI
on:
release:
types: [published]
workflow_dispatch:
jobs:
release:
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/mistral-vibe
permissions:
id-token: write
contents: read
steps:
- name: Checkout code
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
- name: Set up Python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
with:
python-version: "3.12"
- name: Install uv
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
- name: Install dependencies
run: uv sync --locked --dev
- name: Build package
run: uv build
- name: Upload artifacts
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
with:
name: dist
path: dist/
- name: Publish distribution to PyPI
if: github.repository == 'mistralai/mistral-vibe'
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0

201
.gitignore vendored Normal file
View File

@@ -0,0 +1,201 @@
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.*cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# nix / direnv
.direnv/
result
result-*
# Vibe runtime/session files; keep tools as part of repo when needed
.vibe/*
# Tests run the agent in the playground, we don't need to keep the session files
tests/playground/*
.

34
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,34 @@
---
repos:
- repo: https://github.com/mpalmer/action-validator
rev: v0.7.0
hooks:
- id: action-validator
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
hooks:
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
exclude: tests/snapshots/.*\.svg$
- repo: https://github.com/fsouza/mirrors-pyright
rev: v1.1.407
hooks:
- id: pyright
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.5
hooks:
- id: ruff-check
args: [--fix, --unsafe-fixes]
- id: ruff-format
args: [--check]
- repo: https://github.com/crate-ci/typos
rev: v1.34.0
hooks:
- id: typos
args: [--write-changes]

1
.python-version Normal file
View File

@@ -0,0 +1 @@
3.12

2
.typos.toml Normal file
View File

@@ -0,0 +1,2 @@
[default]
extend-ignore-re = ["(?m)^.*(#|//)\\s*typos:disable-line$", "datas"]

3
.vscode/extensions.json vendored Normal file
View File

@@ -0,0 +1,3 @@
{
"recommendations": ["ms-python.python", "charliermarsh.ruff"]
}

59
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,59 @@
{
"version": "0.1.0",
"configurations": [
{
"name": "ACP Server",
"type": "debugpy",
"request": "launch",
"program": "vibe/acp/entrypoint.py",
"args": ["--workdir", "${workspaceFolder}"],
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "Tests",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": ["-v", "-s"],
"console": "integratedTerminal",
"justMyCode": false,
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}"
}
},
{
"name": "Single Test",
"type": "debugpy",
"request": "launch",
"module": "pytest",
"args": ["-k", "${input:test_identifier}", "-vvv", "-s", "--no-header"],
"console": "integratedTerminal",
"justMyCode": false,
"cwd": "${workspaceFolder}",
"env": {
"PYTHONPATH": "${workspaceFolder}"
},
"stopOnEntry": false,
"subProcess": true
},
{
"name": "CLI",
"type": "debugpy",
"request": "launch",
"program": "vibe/cli/entrypoint.py",
"args": [],
"console": "integratedTerminal",
"justMyCode": false
}
],
"inputs": [
{
"id": "test_identifier",
"description": "Enter the test identifier (file, class, or function)",
"default": "TestInitialization",
"type": "promptString"
}
]
}

26
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,26 @@
{
"[python]": {
"editor.codeActionsOnSave": {
"source.fixAll.ruff": "explicit",
"source.organizeImports.ruff": "explicit"
},
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true
},
"cursorpyright.analysis.typeCheckingMode": "strict",
"editor.formatOnSave": true,
"files.exclude": {
".pytest_cache/**": true,
".venv/**": true,
"**/__pycache__": true,
"dist/**": true,
"build/**": true
},
"files.insertFinalNewline": true,
"files.trimTrailingWhitespace": true,
"python.analysis.typeCheckingMode": "strict",
"python.testing.pytestArgs": ["tests"],
"python.testing.pytestEnabled": true,
"ruff.enable": true,
"ruff.organizeImports": true
}

135
AGENTS.md Normal file
View File

@@ -0,0 +1,135 @@
# python312.rule
# Rule for enforcing modern Python 3.12+ best practices.
# Applies to all Python files (*.py) in the project.
#
# Guidelines covered:
# - Use match-case syntax instead of if/elif/else for pattern matching.
# - Use the walrus operator (:=) when it simplifies assignments and tests.
# - Favor a "never nester" approach by avoiding deep nesting with early returns or guard clauses.
# - Employ modern type hints using built-in generics (list, dict) and the union pipe (|) operator,
# rather than deprecated types from the typing module (e.g., Optional, Union, Dict, List).
# - Ensure code adheres to strong static typing practices compatible with static analyzers like pyright.
# - Favor pathlib.Path methods for file system operations over older os.path functions.
# - Write code in a declarative and minimalist style that clearly expresses its intent.
# - Additional best practices including f-string formatting, comprehensions, context managers, and overall PEP 8 compliance.
description: "Modern Python 3.12+ best practices and style guidelines for coding."
files: "**/*.py"
guidelines:
- title: "Match-Case Syntax"
description: >
Prefer using the match-case construct over traditional if/elif/else chains when pattern matching
is applicable. This leads to clearer, more concise, and more maintainable code.
- title: "Walrus Operator"
description: >
Utilize the walrus operator (:=) to streamline code where assignment and conditional testing can be combined.
Use it judiciously when it improves readability and reduces redundancy.
- title: "Never Nester"
description: >
Aim to keep code flat by avoiding deep nesting. Use early returns, guard clauses, and refactoring to
minimize nested structures, making your code more readable and maintainable.
- title: "Modern Type Hints"
description: >
Adopt modern type hinting by using built-in generics like list and dict, along with the pipe (|) operator
for union types (e.g., int | None). Avoid older, deprecated constructs such as Optional, Union, Dict, and List
from the typing module.
- title: "Strong Static Typing"
description: >
Write code with explicit and robust type annotations that are fully compatible with static type checkers
like pyright. This ensures higher code reliability and easier maintenance.
- title: "Pydantic-First Parsing"
description: >
Prefer Pydantic v2's native validation over ad-hoc parsing. Use `model_validate`,
`field_validator`, `from_attributes`, and field aliases to coerce external SDK/DTO objects.
Avoid manual `getattr`/`hasattr` flows and custom constructors like `from_sdk` unless they are
thin wrappers over `model_validate`. Keep normalization logic inside model validators so call sites
remain declarative and typed.
- title: "Pathlib for File Operations"
description: >
Favor the use of pathlib.Path methods for file system operations. This approach offers a more
readable, object-oriented way to handle file paths and enhances cross-platform compatibility,
reducing reliance on legacy os.path functions.
- title: "Declarative and Minimalist Code"
description: >
Write code that is declarative—clearly expressing its intentions rather than focusing on implementation details.
Strive to keep your code minimalist by removing unnecessary complexity and boilerplate. This approach improves
readability, maintainability, and aligns with modern Python practices.
- title: "Additional Best Practices"
description: >
Embrace other modern Python idioms such as:
- Using f-strings for string formatting.
- Favoring comprehensions for building lists and dictionaries.
- Employing context managers (with statements) for resource management.
- Following PEP 8 guidelines to maintain overall code style consistency.
- title: "Exception Documentation"
description: >
Document exceptions accurately and minimally in docstrings:
- Only document exceptions that are explicitly raised in the function implementation
- Remove Raises entries for exceptions that are not directly raised
- Include all possible exceptions from explicit raise statements
- For public APIs, document exceptions from called functions if they are allowed to propagate
- Avoid documenting built-in exceptions that are obvious (like TypeError from type hints)
This ensures documentation stays accurate and maintainable, avoiding the common pitfall
of listing every possible exception that could theoretically occur.
- title: "Modern Enum Usage"
description: >
Leverage Python's enum module effectively following modern practices:
- Use StrEnum for string-based enums that need string comparison
- Use IntEnum/IntFlag for performance-critical integer-based enums
- Use auto() for automatic value assignment to maintain clean code
- Always use UPPERCASE for enum members to avoid name clashes
- Add methods to enums when behavior needs to be associated with values
- Use @property for computed attributes rather than storing values
- For type mixing, ensure mix-in types appear before Enum in base class sequence
- Consider Flag/IntFlag for bit field operations
- Use _generate_next_value_ for custom value generation
- Implement __bool__ when enum boolean evaluation should depend on value
This promotes type-safe constants, self-documenting code, and maintainable value sets.
- title: "No Inline Ignores"
description: >
Do not use inline suppressions like `# type: ignore[...]` or `# noqa[...]` in production code.
Instead, fix types and lint warnings at the source by:
- Refining signatures with generics (TypeVar), Protocols, or precise return types
- Guarding with `isinstance` checks before attribute access
- Using `typing.cast` when control flow guarantees the type
- Extracting small helpers to create clearer, typed boundaries
If a suppression is truly unavoidable at an external boundary, prefer a narrow, well-typed wrapper
over in-line ignores, and document the rationale in code comments.
- title: "Pydantic Discriminated Unions"
description: >
When modeling variants with a discriminated union (e.g., on a `transport` field), do not narrow a
field type in a subclass (e.g., overriding `transport: Literal['http']` with `Literal['streamable-http']`).
This violates Liskov substitution and triggers type checker errors due to invariance of class attributes.
Prefer sibling classes plus a shared mixin for common fields and helpers, and compose the union with
`Annotated[Union[...], Field(discriminator='transport')]`.
Example pattern:
- Create a base with shared non-discriminator fields (e.g., `_MCPBase`).
- Create a mixin with protocol-specific fields/methods (e.g., `_MCPHttpFields`), without a `transport`.
- Define sibling final classes per variant (e.g., `MCPHttp`, `MCPStreamableHttp`, `MCPStdio`) that set
`transport: Literal[...]` once in each final class.
- Use `match` on the discriminator to narrow types at call sites.
- title: "Use uv for All Commands"
description: >
      We use uv to manage our python environment. You should never try to run bare python commands.
Always run commands using `uv` instead of invoking `python` or `pip` directly.
For example, use `uv add package` and `uv run script.py` rather than `pip install package` or `python script.py`.
This practice helps avoid environment drift and leverages modern Python packaging best practices.
Useful uv commands are:
- uv add/remove <package> to manage dependencies
- uv sync to install dependencies declared in pyproject.toml and uv.lock
- uv run script.py to run a script within the uv environment
- uv run pytest (or any other python tool) to run the tool within the uv environment

10
CHANGELOG.md Normal file
View File

@@ -0,0 +1,10 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Initial release

168
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,168 @@
# Contributing to Mistral Vibe
Thank you for your interest in Mistral Vibe! We appreciate your enthusiasm and support.
## Current Status
**Mistral Vibe is in active development** — our team is iterating quickly and making lots of changes under the hood. Because of this pace, we may be slower than usual when reviewing PRs and issues.
**We especially encourage**:
- **Bug reports** — Help us uncover and squash issues
- **Feedback & ideas** — Tell us what works, what doesn't, and what could be even better
- **Documentation improvements** — Suggest clarity improvements or highlight missing pieces
## How to Provide Feedback
### Bug Reports
If you encounter a bug, please open an issue with the following information:
1. **Description**: A clear description of the bug
2. **Steps to Reproduce**: Detailed steps to reproduce the issue
3. **Expected Behavior**: What you expected to happen
4. **Actual Behavior**: What actually happened
5. **Environment**:
- Python version
- Operating system
- Vibe version
6. **Error Messages**: Any error messages or stack traces
7. **Configuration**: Relevant parts of your `config.toml` (redact any sensitive information)
### Feature Requests and Feedback
We'd love to hear your ideas! When submitting feedback or feature requests:
1. **Clear Description**: Explain what you'd like to see or improve
2. **Use Case**: Describe your use case and why this would be valuable
3. **Alternatives**: If applicable, mention any alternatives you've considered
## Development Setup
This section is for developers who want to set up the repository for local development, even though we're not currently accepting contributions.
### Prerequisites
- Python 3.12 or higher
- [uv](https://github.com/astral-sh/uv) - Modern Python package manager
### Setup
1. Clone the repository:
```bash
git clone <repository-url>
cd mistral-vibe
```
2. Install dependencies:
```bash
uv sync --all-extras
```
This will install both runtime and development dependencies.
3. (Optional) Install pre-commit hooks:
```bash
uv run pre-commit install
```
Pre-commit hooks will automatically run checks before each commit.
### Running Tests
Run all tests:
```bash
uv run pytest
```
Run tests with verbose output:
```bash
uv run pytest -v
```
Run a specific test file:
```bash
uv run pytest tests/test_agent_tool_call.py
```
### Linting and Type Checking
#### Ruff (Linting and Formatting)
Check for linting issues (without fixing):
```bash
uv run ruff check .
```
Auto-fix linting issues:
```bash
uv run ruff check --fix .
```
Format code:
```bash
uv run ruff format .
```
Check formatting without modifying files (useful for CI):
```bash
uv run ruff format --check .
```
#### Pyright (Type Checking)
Run type checking:
```bash
uv run pyright
```
#### Pre-commit Hooks
Run all pre-commit hooks manually:
```bash
uv run pre-commit run --all-files
```
The pre-commit hooks include:
- Ruff (linting and formatting)
- Pyright (type checking)
- Typos (spell checking)
- YAML/TOML validation
- Action validator (for GitHub Actions)
### Code Style
- **Line length**: 88 characters (Black-compatible)
- **Type hints**: Required for all functions and methods
- **Docstrings**: Follow Google-style docstrings
- **Formatting**: Use Ruff for both linting and formatting
- **Type checking**: Use Pyright (configured in `pyproject.toml`)
See `pyproject.toml` for detailed configuration of Ruff and Pyright.
## Code Contributions
While we're not accepting code contributions at the moment, we may open up contributions in the future. When that happens, we'll update this document with:
- Pull request process
- Contribution guidelines
- Review process
## Questions?
If you have questions about using Mistral Vibe, please check the [README](README.md) first. For other inquiries, feel free to open a discussion or issue.
Thank you for helping make Mistral Vibe better! 🙏

201
LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2025 Mistral AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

308
README.md Normal file
View File

@@ -0,0 +1,308 @@
# Mistral Vibe
[![PyPI Version](https://img.shields.io/pypi/v/mistral-vibe)](https://pypi.org/project/mistral-vibe)
[![Python Version](https://img.shields.io/badge/python-3.12%2B-blue)](https://www.python.org/downloads/release/python-3120/)
[![CI Status](https://github.com/mistralai/mistral-vibe/actions/workflows/ci.yml/badge.svg)](https://github.com/mistralai/mistral-vibe/actions/workflows/ci.yml)
[![License](https://img.shields.io/github/license/mistralai/mistral-vibe)](https://github.com/mistralai/mistral-vibe/blob/main/LICENSE)
```
██████████████████░░
██████████████████░░
████ ██████ ████░░
████ ██ ████░░
████ ████░░
████ ██ ██ ████░░
██ ██ ██░░
██████████████████░░
██████████████████░░
```
**Mistral's open-source CLI coding assistant.**
Mistral Vibe is a command-line coding assistant powered by Mistral's models. It provides a conversational interface to your codebase, allowing you to use natural language to explore, modify, and interact with your projects through a powerful set of tools.
> [!WARNING]
> Mistral Vibe works on Windows, but we officially support and target UNIX environments.
## Installation
Vibe requires Python 3.12 or higher.
### One-line install (recommended)
```bash
# On Linux and macOS
curl -LsSf https://mistral.ai/vibe/install.sh | bash
```
### Using uv
```bash
uv tool install mistral-vibe
```
### Using pip
```bash
pip install mistral-vibe
```
## Features
- **Interactive Chat**: A conversational AI agent that understands your requests and breaks down complex tasks.
- **Powerful Toolset**: A suite of tools for file manipulation, code searching, version control, and command execution, right from the chat prompt.
- Read, write, and patch files (`read_file`, `write_file`, `search_replace`).
- Execute shell commands in a stateful terminal (`bash`).
- Recursively search code with `grep` (with `ripgrep` support).
- Manage a `todo` list to track the agent's work.
- **Project-Aware Context**: Vibe automatically scans your project's file structure and Git status to provide relevant context to the agent, improving its understanding of your codebase.
- **Advanced CLI Experience**: Built with modern libraries for a smooth and efficient workflow.
- Autocompletion for slash commands (`/`) and file paths (`@`).
- Persistent command history.
- Beautiful Themes.
- **Highly Configurable**: Customize models, providers, tool permissions, and UI preferences through a simple `config.toml` file.
- **Safety First**: Features tool execution approval.
## Quick Start
1. Navigate to your project's root directory:
```bash
cd /path/to/your/project
```
2. Run Vibe:
```bash
vibe
```
3. If this is your first time running Vibe, it will:
- Create a default configuration file at `~/.vibe/config.toml`
- Prompt you to enter your API key if it's not already configured
- Save your API key to `~/.vibe/.env` for future use
4. Start interacting with the agent!
```
> Can you find all instances of the word "TODO" in the project?
🤖 The user wants to find all instances of "TODO". The `grep` tool is perfect for this. I will use it to search the current directory.
> grep(pattern="TODO", path=".")
... (grep tool output) ...
🤖 I found the following "TODO" comments in your project.
```
## Usage
### Interactive Mode
Simply run `vibe` to enter the interactive chat loop.
- **Multi-line Input**: Press `Ctrl+J` (or `Shift+Enter` in supported terminals) to insert a newline.
- **File Paths**: Reference files in your prompt using the `@` symbol for smart autocompletion (e.g., `> Read the file @src/agent.py`).
- **Shell Commands**: Prefix any command with `!` to execute it directly in your shell, bypassing the agent (e.g., `> !ls -l`).
You can start Vibe with a prompt with the following command:
```bash
vibe "Refactor the main function in cli/main.py to be more modular."
```
**Note**: The `--auto-approve` flag automatically approves all tool executions without prompting. In interactive mode, you can also toggle auto-approve on/off using `Shift+Tab`.
### Programmatic Mode
You can run Vibe non-interactively by piping input or using the `--prompt` flag. This is useful for scripting.
```bash
vibe --prompt "Refactor the main function in cli/main.py to be more modular."
```
by default it will use `auto-approve` mode.
### Slash Commands
Use slash commands for meta-actions and configuration changes during a session.
## Configuration
Vibe is configured via a `config.toml` file. It looks for this file first in `./.vibe/config.toml` and then falls back to `~/.vibe/config.toml`.
### API Key Configuration
Vibe supports multiple ways to configure your API keys:
1. **Interactive Setup (Recommended for first-time users)**: When you run Vibe for the first time or if your API key is missing, Vibe will prompt you to enter it. The key will be securely saved to `~/.vibe/.env` for future sessions.
2. **Environment Variables**: Set your API key as an environment variable:
```bash
export MISTRAL_API_KEY="your_mistral_api_key"
```
3. **`.env` File**: Create a `.env` file in `~/.vibe/` and add your API keys:
```bash
MISTRAL_API_KEY=your_mistral_api_key
```
Vibe automatically loads API keys from `~/.vibe/.env` on startup. Environment variables take precedence over the `.env` file if both are set.
**Note**: The `.env` file is specifically for API keys and other provider credentials. General Vibe configuration should be done in `config.toml`.
### Custom System Prompts
You can create custom system prompts to replace the default one (`prompts/core.md`). Create a markdown file in the `~/.vibe/prompts/` directory with your custom prompt content.
To use a custom system prompt, set the `system_prompt_id` in your configuration to match the filename (without the `.md` extension):
```toml
# Use a custom system prompt
system_prompt_id = "my_custom_prompt"
```
This will load the prompt from `~/.vibe/prompts/my_custom_prompt.md`.
### Custom Agent Configurations
You can create custom agent configurations for specific use cases (e.g., red-teaming, specialized tasks) by adding agent-specific TOML files in the `~/.vibe/agents/` directory.
To use a custom agent, run Vibe with the `--agent` flag:
```bash
vibe --agent my_custom_agent
```
Vibe will look for a file named `my_custom_agent.toml` in the agents directory and apply its configuration.
Example custom agent configuration (`~/.vibe/agents/redteam.toml`):
```toml
# Custom agent configuration for red-teaming
active_model = "devstral-2"
system_prompt_id = "redteam"
# Disable some tools for this agent
disabled_tools = ["search_replace", "write_file"]
# Override tool permissions for this agent
[tools.bash]
permission = "always"
[tools.read_file]
permission = "always"
```
Note: this implies that you have set up a redteam prompt named `~/.vibe/prompts/redteam.md`
### MCP Server Configuration
You can configure MCP (Model Context Protocol) servers to extend Vibe's capabilities. Add MCP server configurations under the `mcp_servers` section:
```toml
# Example MCP server configurations
[[mcp_servers]]
name = "my_http_server"
transport = "http"
url = "http://localhost:8000"
headers = { "Authorization" = "Bearer my_token" }
api_key_env = "MY_API_KEY_ENV_VAR"
api_key_header = "Authorization"
api_key_format = "Bearer {token}"
[[mcp_servers]]
name = "my_streamable_server"
transport = "streamable-http"
url = "http://localhost:8001"
headers = { "X-API-Key" = "my_api_key" }
[[mcp_servers]]
name = "fetch_server"
transport = "stdio"
command = "uvx"
args = ["mcp-server-fetch"]
```
Supported transports:
- `http`: Standard HTTP transport
- `streamable-http`: HTTP transport with streaming support
- `stdio`: Standard input/output transport (for local processes)
Key fields:
- `name`: A short alias for the server (used in tool names)
- `transport`: The transport type
- `url`: Base URL for HTTP transports
- `headers`: Additional HTTP headers
- `api_key_env`: Environment variable containing the API key
- `command`: Command to run for stdio transport
- `args`: Additional arguments for stdio transport
### Enable/disable tools with patterns
You can control which tools are active using `enabled_tools` and `disabled_tools`.
These fields support exact names, glob patterns, and regular expressions.
Examples:
```toml
# Only enable tools that start with "serena_" (glob)
enabled_tools = ["serena_*"]
# Regex (prefix with re:) — matches full tool name (case-insensitive)
enabled_tools = ["re:^serena_.*$"]
# Heuristic regex support (patterns like `serena.*` are treated as regex)
enabled_tools = ["serena.*"]
# Disable a group with glob; everything else stays enabled
disabled_tools = ["mcp_*", "grep"]
```
Notes:
- MCP tool names use underscores, e.g., `serena_list` not `serena.list`.
- Regex patterns are matched against the full tool name using fullmatch.
### Custom Vibe Home Directory
By default, Vibe stores its configuration in `~/.vibe/`. You can override this by setting the `VIBE_HOME` environment variable:
```bash
export VIBE_HOME="/path/to/custom/vibe/home"
```
This affects where Vibe looks for:
- `config.toml` - Main configuration
- `.env` - API keys
- `agents/` - Custom agent configurations
- `prompts/` - Custom system prompts
- `tools/` - Custom tools
- `logs/` - Session logs
## Resources
- [CHANGELOG](CHANGELOG.md) - See what's new in each version
- [CONTRIBUTING](CONTRIBUTING.md) - Guidelines for feedback and bug reports
## License
Copyright 2025 Mistral AI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the [LICENSE](LICENSE) file for the full license text.

64
action.yml Normal file
View File

@@ -0,0 +1,64 @@
---
name: Mistral Vibe
description: "Download, install, and run Mistral Vibe"
author: Mistral AI

inputs:
  prompt:
    description: The prompt to pass to the agent
    required: true
    default: |
      You are a helpful assistant
  MISTRAL_API_KEY:
    description: API Key for AI Studio
    required: true
  install_python:
    description: |
      Whether or not to install Python
    required: true
    default: "true"
  python_version:
    description: |
      Version of Python to install. Warning: Unsupported.
    required: false

runs:
  using: "composite"
  steps:
    - name: Install Required Python Version
      # Composite-action inputs are always strings, so compare against the
      # string 'true'. The previous `inputs.install_python == true` compared a
      # string to a boolean and never matched, so this step never ran.
      if: inputs.install_python == 'true'
      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
      with:
        python-version-file: ${{ github.action_path }}/.python-version
    - name: Install Requested Python Version
      # Only run when installation is enabled AND a version was requested.
      # The previous bare truthiness check also ran for install_python='false'
      # because any non-empty string is truthy in Actions expressions.
      if: inputs.install_python == 'true' && inputs.python_version != ''
      uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
      with:
        python-version: ${{ inputs.python_version }}
    - name: Install uv
      uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
    - name: Install Mistral Vibe
      shell: bash
      working-directory: ${{ github.action_path }}
      run: |
        uv sync --locked --all-extras --dev
    - name: Run Mistral Vibe
      id: run-mistral-vibe
      shell: bash
      working-directory: ${{ github.action_path }}
      env:
        MISTRAL_API_KEY: "${{ inputs.MISTRAL_API_KEY }}"
      run: |
        # We want to make sure that any text passed in here
        # doesn't have special bash characters (<, >, &, etc...)
        ESCAPED_PROMPT=$(printf '%q' "${{ inputs.prompt }}")
        # Change back to the original working directory for the tool to work
        cd "${{ github.workspace }}"
        uv run --directory "${{ github.action_path }}" vibe \
          --auto-approve \
          -p "${ESCAPED_PROMPT}"

View File

@@ -0,0 +1,35 @@
id = "mistral-vibe"
name = "Mistral Vibe"
description = "Lightning-fast AI agent that actually gets things done"
version = "1.0.0"
schema_version = 1
authors = ["Mistral AI"]
repository = "https://github.com/mistralai/mistral-vibe"
[agent_servers.mistral-vibe-agent]
name = "Mistral Vibe"
icon = "./icons/mistral_vibe.svg"
[agent_servers.mistral-vibe-agent.targets.darwin-aarch64]
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-darwin-aarch64-1.0.0.zip"
cmd = "./vibe-acp"
[agent_servers.mistral-vibe-agent.targets.darwin-x86_64]
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-darwin-x86_64-1.0.0.zip"
cmd = "./vibe-acp"
# [agent_servers.mistral-vibe-agent.targets.linux-aarch64]
# archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-linux-aarch64-1.0.0.zip"
# cmd = "./vibe-acp"
[agent_servers.mistral-vibe-agent.targets.linux-x86_64]
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-linux-x86_64-1.0.0.zip"
cmd = "./vibe-acp"
# [agent_servers.mistral-vibe-agent.targets.windows-aarch64]
# archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-windows-aarch64-1.0.0.zip"
# cmd = "./vibe-acp.exe"
[agent_servers.mistral-vibe-agent.targets.windows-x86_64]
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-windows-x86_64-1.0.0.zip"
cmd = "./vibe-acp.exe"

View File

@@ -0,0 +1,13 @@
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
<rect width="512" height="512" fill="#24211E"/>
<rect x="113.778" y="113.778" width="56.8889" height="56.8889" fill="white"/>
<rect x="341.333" y="113.778" width="56.8889" height="56.8889" fill="white"/>
<rect x="227.556" y="284.444" width="56.8889" height="56.8889" fill="white"/>
<rect x="113.778" y="284.444" width="56.8889" height="56.8889" fill="white"/>
<rect x="341.333" y="284.444" width="56.8889" height="56.8889" fill="white"/>
<rect x="113.778" y="170.667" width="113.778" height="56.8889" fill="white"/>
<rect x="56.8887" y="341.333" width="170.667" height="56.8889" fill="white"/>
<rect x="284.444" y="341.333" width="170.667" height="56.8889" fill="white"/>
<rect x="113.778" y="227.556" width="284.444" height="56.8889" fill="white"/>
<rect x="284.444" y="170.667" width="113.778" height="56.8889" fill="white"/>
</svg>

After

Width:  |  Height:  |  Size: 935 B

133
flake.lock generated Normal file
View File

@@ -0,0 +1,133 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1763283776,
"narHash": "sha256-Y7TDFPK4GlqrKrivOcsHG8xSGqQx3A6c+i7novT85Uk=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "50a96edd8d0db6cc8db57dab6bb6d6ee1f3dc49a",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1761781027,
"narHash": "sha256-YDvxPAm2WnxrznRqWwHLjryBGG5Ey1ATEJXrON+TWt8=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "795a980d25301e5133eca37adae37283ec3c8e66",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs",
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"uv2nix": "uv2nix"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1763349549,
"narHash": "sha256-GQKYN9j8HOh09RW2I739tyu87ygcsAmpJJ32FspWVJ8=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "071b718279182c5585f74939c2902c202f93f588",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

144
flake.nix Normal file
View File

@@ -0,0 +1,144 @@
{
  description = "Mistral Vibe!";

  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
    # pyproject.nix provides the base Python package builders.
    pyproject-nix = {
      url = "github:pyproject-nix/pyproject.nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
    # uv2nix converts the uv workspace (pyproject.toml + uv.lock) to Nix.
    uv2nix = {
      url = "github:pyproject-nix/uv2nix";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
    pyproject-build-systems = {
      url = "github:pyproject-nix/build-system-pkgs";
      inputs.pyproject-nix.follows = "pyproject-nix";
      inputs.uv2nix.follows = "uv2nix";
      inputs.nixpkgs.follows = "nixpkgs";
    };
  };

  outputs = {
    self,
    nixpkgs,
    flake-utils,
    uv2nix,
    pyproject-nix,
    pyproject-build-systems,
    ...
  }:
    flake-utils.lib.eachDefaultSystem (system: let
      inherit (nixpkgs) lib;
      # Load the uv workspace from the repository root.
      workspace = uv2nix.lib.workspace.loadWorkspace {workspaceRoot = ./.;};
      # Overlay mapping locked Python dependencies onto Nix packages.
      overlay = workspace.mkPyprojectOverlay {
        sourcePreference = "wheel"; # sdist if you want
      };
      pyprojectOverrides = final: prev: {
        # NOTE: If a package complains about a missing dependency (such
        # as setuptools), you can add it here.
        untokenize = prev.untokenize.overrideAttrs (old: {
          buildInputs = (old.buildInputs or []) ++ final.resolveBuildSystem {setuptools = [];};
        });
      };
      pkgs = import nixpkgs {
        inherit system;
      };
      python = pkgs.python312;
      # Construct package set
      pythonSet =
        # Use base package set from pyproject.nix builders
        (pkgs.callPackage pyproject-nix.build.packages {
          inherit python;
        }).overrideScope
        (
          lib.composeManyExtensions [
            pyproject-build-systems.overlays.default
            overlay
            pyprojectOverrides
          ]
        );
    in {
      # Virtualenv containing vibe and its locked runtime dependencies.
      packages.default = pythonSet.mkVirtualEnv "mistralai-vibe-env" workspace.deps.default;
      apps = {
        default = {
          type = "app";
          program = "${self.packages.${system}.default}/bin/vibe";
        };
      };
      devShells = {
        # Dev shell with an editable install of the workspace package.
        default = let
          editableOverlay = workspace.mkEditablePyprojectOverlay {
            root = "$REPO_ROOT";
          };
          editablePythonSet = pythonSet.overrideScope (
            lib.composeManyExtensions [
              editableOverlay
              # Apply fixups for building an editable package of your workspace packages
              (final: prev: {
                mistralai-vibe = prev.mistralai-vibe.overrideAttrs (old: {
                  # It's a good idea to filter the sources going into an editable build
                  # so the editable package doesn't have to be rebuilt on every change.
                  src = lib.fileset.toSource {
                    root = old.src;
                    fileset = lib.fileset.unions [
                      (old.src + "/pyproject.toml")
                      (old.src + "/README.md")
                    ];
                  };
                  nativeBuildInputs =
                    old.nativeBuildInputs
                    ++ final.resolveBuildSystem {
                      editables = [];
                    };
                });
              })
            ]
          );
          virtualenv = editablePythonSet.mkVirtualEnv "mistralai-vibe-dev-env" workspace.deps.all;
        in
          pkgs.mkShell {
            packages = [
              virtualenv
              pkgs.uv
            ];
            env = {
              # Don't create venv using uv
              UV_NO_SYNC = "1";
              # Force uv to use Python interpreter from venv
              UV_PYTHON = "${virtualenv}/bin/python";
              # Prevent uv from downloading managed Python's
              UV_PYTHON_DOWNLOADS = "never";
            };
            shellHook = ''
              # Undo dependency propagation by nixpkgs.
              unset PYTHONPATH
              # Get repository root using git. This is expanded at runtime by the editable `.pth` machinery.
              export REPO_ROOT=$(git rev-parse --show-toplevel)
            '';
          };
      };
    });
}

157
pyproject.toml Normal file
View File

@@ -0,0 +1,157 @@
[project]
name = "mistral-vibe"
version = "1.0.0"
description = "Minimal CLI coding agent by Mistral"
readme = "README.md"
requires-python = ">=3.12"
license = { text = "Apache-2.0" }
authors = [{ name = "Mistral AI" }]
keywords = [
"ai",
"cli",
"coding-assistant",
"mistral",
"llm",
"developer-tools",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.12",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Utilities",
]
# Runtime dependencies (resolved and pinned via uv.lock).
dependencies = [
    "agent-client-protocol==0.6.3",
    "aiofiles>=24.1.0",
    "httpx>=0.28.1",
    "mcp>=1.14.0",
    "mistralai==1.9.11",
    "pexpect>=4.9.0",
    "packaging>=24.1",
    "pydantic>=2.12.4",
    "pydantic-settings>=2.12.0",
    "python-dotenv>=1.0.0",
    # NOTE(review): pytest-xdist is a pytest plugin; it looks misplaced among
    # runtime dependencies — presumably belongs in the dev dependency group.
    # Confirm before moving, in case something imports it at runtime.
    "pytest-xdist>=3.8.0",
    "rich>=14.0.0",
    "textual>=1.0.0",
    "tomli-w>=1.2.0",
    "watchfiles>=1.1.1",
    "pyperclip>=1.11.0",
]
[project.urls]
Homepage = "https://github.com/mistralai/mistral-vibe"
Repository = "https://github.com/mistralai/mistral-vibe"
Issues = "https://github.com/mistralai/mistral-vibe/issues"
Documentation = "https://github.com/mistralai/mistral-vibe#readme"
[build-system]
requires = ["hatchling", "hatch-vcs", "editables"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
include = ["vibe/"]
[project.scripts]
vibe = "vibe.cli.entrypoint:main"
vibe-acp = "vibe.acp.entrypoint:main"
[tool.uv]
package = true
required-version = ">=0.8.0"
[dependency-groups]
dev = [
"pre-commit>=4.2.0",
"pyright>=1.1.403",
"pytest>=8.3.5",
"pytest-asyncio>=1.2.0",
"pytest-timeout>=2.4.0",
"pytest-textual-snapshot>=1.1.0",
"respx>=0.22.0",
"ruff>=0.14.5",
"twine>=5.0.0",
"typos>=1.34.0",
"vulture>=2.14",
"pyinstaller>=6.17.0",
]
[tool.pyright]
pythonVersion = "3.12"
reportMissingTypeStubs = false
reportPrivateImportUsage = false
include = ["vibe/**/*.py", "tests/**/*.py"]
venvPath = "."
venv = ".venv"
[tool.ruff]
include = ["vibe/**/*.py", "tests/**/*.py"]
line-length = 88
target-version = "py312"
preview = true
[tool.ruff.format]
skip-magic-trailing-comma = true
[tool.ruff.lint]
select = [
"F",
"I",
"D2",
"UP",
"TID",
"ANN",
"PLR",
"B0",
"B905",
"DOC102",
"RUF022",
"RUF010",
"RUF012",
"RUF019",
"RUF100",
]
ignore = ["D203", "D205", "D213", "ANN401", "PLR6301"]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["ANN", "PLR"]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"
[tool.ruff.lint.isort]
known-first-party = ["vibe"]
force-sort-within-sections = true
split-on-trailing-comma = true
combine-as-imports = true
force-wrap-aliases = false
order-by-type = true
required-imports = ["from __future__ import annotations"]
[tool.ruff.lint.pylint]
max-statements = 50
max-branches = 15
max-locals = 15
max-args = 9
max-returns = 6
max-nested-blocks = 4
[tool.vulture]
ignore_decorators = ["@*"]
[tool.pytest.ini_options]
addopts = "-vvvv -q -n auto --durations=5 --import-mode=importlib"
timeout = 10

20
scripts/README.md Normal file
View File

@@ -0,0 +1,20 @@
# Project Management Scripts
This directory contains scripts that support project versioning and deployment workflows.
## Versioning
### Usage
```bash
# Bump major version (1.0.0 -> 2.0.0)
uv run scripts/bump_version.py major
# Bump minor version (1.0.0 -> 1.1.0)
uv run scripts/bump_version.py minor
# Bump patch/micro version (1.0.0 -> 1.0.1)
uv run scripts/bump_version.py micro
# or
uv run scripts/bump_version.py patch
```

138
scripts/bump_version.py Executable file
View File

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
"""Version bumping script for semver versioning.
This script increments the version in pyproject.toml based on the specified bump type:
- major: 1.0.0 -> 2.0.0
- minor: 1.0.0 -> 1.1.0
- micro/patch: 1.0.0 -> 1.0.1
"""
from __future__ import annotations
import argparse
from pathlib import Path
import re
import subprocess
import sys
from typing import Literal, get_args
BumpType = Literal["major", "minor", "micro", "patch"]
BUMP_TYPES = get_args(BumpType)
def parse_version(version_str: str) -> tuple[int, int, int]:
    """Split a semver string like '1.2.3' into its integer components.

    Args:
        version_str: Version text; surrounding whitespace is ignored.

    Returns:
        (major, minor, patch) as integers.

    Raises:
        ValueError: if the string is not exactly MAJOR.MINOR.PATCH.
    """
    parts = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version_str.strip())
    if parts is None:
        raise ValueError(f"Invalid version format: {version_str}")
    major, minor, patch = (int(piece) for piece in parts.groups())
    return major, minor, patch
def format_version(major: int, minor: int, patch: int) -> str:
    """Join version components back into a 'MAJOR.MINOR.PATCH' string."""
    return ".".join(str(part) for part in (major, minor, patch))
def bump_version(version: str, bump_type: BumpType) -> str:
    """Return `version` bumped according to semver rules.

    major resets minor and patch; minor resets patch; micro/patch bumps the
    last component. Parsing is done inline (same regex contract as
    parse_version): a malformed version raises ValueError.
    """
    matched = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)", version.strip())
    if matched is None:
        raise ValueError(f"Invalid version format: {version}")
    major, minor, patch = (int(g) for g in matched.groups())
    if bump_type == "major":
        return f"{major + 1}.0.0"
    if bump_type == "minor":
        return f"{major}.{minor + 1}.0"
    if bump_type in ("micro", "patch"):
        return f"{major}.{minor}.{patch + 1}"
    # Unreachable for valid BumpType values; falls through to implicit None
    # exactly like the original match statement with no matching case.
def update_hard_values_files(filepath: str, patterns: list[tuple[str, str]]) -> None:
    """Apply regex substitutions to a file, failing loudly on missing patterns.

    Args:
        filepath: Path of the file to rewrite.
        patterns: (regex, replacement) pairs, applied in order with
            re.MULTILINE semantics.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if any pattern produces no change (nothing matched, or
            the replacement is identical to the matched text).
    """
    path = Path(filepath)
    if not path.exists():
        raise FileNotFoundError(f"{filepath} not found in current directory")
    # Read once and apply all substitutions in memory, then write once.
    # The original re-read and re-wrote the file per pattern, so a failing
    # later pattern could leave the file partially updated.
    content = path.read_text()
    for pattern, replacement in patterns:
        updated_content = re.sub(pattern, replacement, content, flags=re.MULTILINE)
        if updated_content == content:
            raise ValueError(f"pattern {pattern} not found in {filepath}")
        content = updated_content
    path.write_text(content)
    print(f"Updated version in {filepath}")
def get_current_version() -> str:
    """Read the project version from ./pyproject.toml.

    Returns:
        The value of the top-level `version = "..."` line.

    Raises:
        FileNotFoundError: if pyproject.toml is absent from the cwd.
        ValueError: if no version line is present.
    """
    pyproject_path = Path("pyproject.toml")
    if not pyproject_path.exists():
        raise FileNotFoundError("pyproject.toml not found in current directory")
    text = pyproject_path.read_text()
    # Anchor to a full line so substrings elsewhere cannot match.
    found = re.search(r'^version = "([^"]+)"$', text, re.MULTILINE)
    if found is None:
        raise ValueError("Version not found in pyproject.toml")
    return found.group(1)
def main() -> None:
    """CLI entry point: bump the project version and refresh derived files.

    Reads the current version from pyproject.toml, computes the new one from
    the requested bump type, rewrites pyproject.toml and the Zed extension
    manifest, then re-resolves the lockfile with `uv lock`. Any failure is
    reported to stderr and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Bump semver version in pyproject.toml",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    uv run scripts/bump_version.py major    # 1.0.0 -> 2.0.0
    uv run scripts/bump_version.py minor    # 1.0.0 -> 1.1.0
    uv run scripts/bump_version.py micro    # 1.0.0 -> 1.0.1
    uv run scripts/bump_version.py patch    # 1.0.0 -> 1.0.1
""",
    )
    parser.add_argument(
        "bump_type", choices=BUMP_TYPES, help="Type of version bump to perform"
    )
    args = parser.parse_args()
    try:
        # Get current version
        current_version = get_current_version()
        print(f"Current version: {current_version}")
        # Calculate new version
        new_version = bump_version(current_version, args.bump_type)
        print(f"New version: {new_version}")
        # Update pyproject.toml
        update_hard_values_files(
            "pyproject.toml",
            [(f'version = "{current_version}"', f'version = "{new_version}"')],
        )
        # Update extension.toml: the version field plus the release-archive
        # URLs and zip names that embed the version string.
        update_hard_values_files(
            "distribution/zed/extension.toml",
            [
                (f'version = "{current_version}"', f'version = "{new_version}"'),
                (
                    f"releases/download/v{current_version}",
                    f"releases/download/v{new_version}",
                ),
                (f"-{current_version}.zip", f"-{new_version}.zip"),
            ],
        )
        # Re-resolve so uv.lock records the new project version.
        subprocess.run(["uv", "lock"], check=True)
        print(f"Successfully bumped version from {current_version} to {new_version}")
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()

128
scripts/install.sh Executable file
View File

@@ -0,0 +1,128 @@
#!/usr/bin/env bash
# Mistral Vibe Installation Script
# This script installs uv if not present and then installs mistral-vibe using uv
set -euo pipefail
# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Colored log helpers. All take a single message argument; error() writes
# to stderr, the others to stdout.
function error() {
    echo -e "${RED}[ERROR]${NC} $1" >&2
}

function info() {
    echo -e "${BLUE}[INFO]${NC} $1"
}

function success() {
    echo -e "${GREEN}[SUCCESS]${NC} $1"
}

function warning() {
    echo -e "${YELLOW}[WARNING]${NC} $1"
}
# Detect the host OS and set the global PLATFORM variable ("linux" or "macos").
# Exits with status 1 on any other platform.
function check_platform() {
    # Declare and assign separately so a failing command substitution is not
    # masked by the exit status of `local` (shellcheck SC2155).
    local platform
    platform=$(uname -s)
    case "$platform" in
        Linux)
            info "Detected Linux platform"
            PLATFORM="linux"
            ;;
        Darwin)
            info "Detected macOS platform"
            PLATFORM="macos"
            ;;
        *)
            error "Unsupported platform: $platform"
            error "This installation script currently only supports Linux and macOS"
            exit 1
            ;;
    esac
}
# Set the global UV_INSTALLED flag ("true"/"false") depending on whether the
# `uv` binary is already on PATH.
function check_uv_installed() {
    if command -v uv &> /dev/null; then
        info "uv is already installed: $(uv --version)"
        UV_INSTALLED=true
    else
        info "uv is not installed"
        UV_INSTALLED=false
    fi
}
# Install uv via the official Astral installer script. Requires curl.
# Extends PATH for the current session so the freshly installed binary is
# usable immediately; exits 1 if curl is missing or the install fails.
function install_uv() {
    info "Installing uv using the official Astral installer..."
    if ! command -v curl &> /dev/null; then
        error "curl is required to install uv. Please install curl first."
        exit 1
    fi
    if curl -LsSf https://astral.sh/uv/install.sh | sh; then
        success "uv installed successfully"
        # The Astral installer places uv in ~/.local/bin by default.
        export PATH="$HOME/.local/bin:$PATH"
        if ! command -v uv &> /dev/null; then
            warning "uv was installed but not found in PATH for this session"
            warning "You may need to restart your terminal or run:"
            # NOTE(review): the hint also mentions ~/.cargo/bin — presumably the
            # legacy uv install location; confirm it is still needed.
            warning " export PATH=\"\$HOME/.cargo/bin:\$HOME/.local/bin:\$PATH\""
        fi
    else
        error "Failed to install uv"
        exit 1
    fi
}
# Install mistral-vibe as a uv tool, exposing the `vibe` and `vibe-acp` commands.
function install_vibe() {
    # Message fixed: `uv tool install mistral-vibe` resolves the package from
    # the configured package index, not from a GitHub repository.
    info "Installing mistral-vibe using uv..."
    uv tool install mistral-vibe
    success "Mistral Vibe installed successfully! (commands: vibe, vibe-acp)"
}
# Orchestrate the install: banner, platform check, uv bootstrap (if needed),
# mistral-vibe install, and a final sanity check that `vibe` is on PATH.
function main() {
    # Mistral logo banner.
    echo
    echo "██████████████████░░"
    echo "██████████████████░░"
    echo "████ ██████ ████░░"
    echo "████ ██ ████░░"
    echo "████ ████░░"
    echo "████ ██ ██ ████░░"
    echo "██ ██ ██░░"
    echo "██████████████████░░"
    echo "██████████████████░░"
    echo
    echo "Starting Mistral Vibe installation..."
    echo
    check_platform
    check_uv_installed
    # Bootstrap uv only when it is not already available.
    if [[ "$UV_INSTALLED" == "false" ]]; then
        install_uv
    fi
    install_vibe
    # Verify the entry point actually landed on PATH before declaring success.
    if command -v vibe &> /dev/null; then
        success "Installation completed successfully!"
        echo
        echo "You can now run vibe with:"
        echo " vibe"
        echo
        echo "Or for ACP mode:"
        echo " vibe-acp"
    else
        error "Installation completed but 'vibe' command not found"
        error "Please check your installation and PATH settings"
        exit 1
    fi
}

main

5
tests/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from __future__ import annotations

from pathlib import Path

# Absolute path to the tests/ package directory; used by test modules to
# locate fixtures and the playground working directory.
TESTS_ROOT = Path(__file__).parent

925
tests/acp/test_acp.py Normal file
View File

@@ -0,0 +1,925 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
import json
import os
from typing import Any
from acp import (
InitializeRequest,
NewSessionRequest,
PromptRequest,
ReadTextFileRequest,
ReadTextFileResponse,
RequestPermissionRequest,
RequestPermissionResponse,
WriteTextFileRequest,
)
from acp.schema import (
AgentCapabilities,
AllowedOutcome,
DeniedOutcome,
Implementation,
InitializeResponse,
McpCapabilities,
NewSessionResponse,
PromptCapabilities,
PromptResponse,
SessionNotification,
TextContentBlock,
)
from pydantic import BaseModel
import pytest
from tests import TESTS_ROOT
from tests.mock.utils import get_mocking_env, mock_llm_chunk
from vibe.acp.utils import ToolOption
from vibe.core.types import FunctionCall, ToolCall
# Default time (seconds) to wait for a single JSON-RPC message from the agent.
RESPONSE_TIMEOUT = 2.0
# Script that starts the ACP agent with a mocked LLM backend.
MOCK_ENTRYPOINT_PATH = "tests/mock/mock_entrypoint.py"
# Scratch directory handed to the agent as the session working directory.
PLAYGROUND_DIR = TESTS_ROOT / "playground"
class JsonRpcRequest(BaseModel):
    """Generic JSON-RPC 2.0 request: carries both a method and an id."""

    jsonrpc: str = "2.0"
    id: int | str
    method: str
    params: Any | None = None


class JsonRpcError(BaseModel):
    """JSON-RPC 2.0 error object embedded in a response."""

    code: int
    message: str
    data: Any | None = None


class JsonRpcResponse(BaseModel):
    """Generic JSON-RPC 2.0 response: one of result/error is populated."""

    jsonrpc: str = "2.0"
    id: int | str | None = None
    result: Any | None = None
    error: JsonRpcError | None = None


class JsonRpcNotification(BaseModel):
    """JSON-RPC 2.0 notification: a method call without an id (no reply)."""

    jsonrpc: str = "2.0"
    method: str
    params: Any | None = None


# Any wire-level message exchanged with the agent process.
type JsonRpcMessage = JsonRpcResponse | JsonRpcNotification | JsonRpcRequest
# Typed specializations of the generic JSON-RPC envelopes, one pair per ACP
# method, so tests get validated params/result payloads instead of raw dicts.
class InitializeJsonRpcRequest(JsonRpcRequest):
    method: str = "initialize"
    params: InitializeRequest | None = None


class InitializeJsonRpcResponse(JsonRpcResponse):
    result: InitializeResponse | None = None


class NewSessionJsonRpcRequest(JsonRpcRequest):
    method: str = "session/new"
    params: NewSessionRequest | None = None


class NewSessionJsonRpcResponse(JsonRpcResponse):
    result: NewSessionResponse | None = None


class PromptJsonRpcRequest(JsonRpcRequest):
    method: str = "session/prompt"
    params: PromptRequest | None = None


class PromptJsonRpcResponse(JsonRpcResponse):
    result: PromptResponse | None = None


class UpdateJsonRpcNotification(JsonRpcNotification):
    method: str = "session/update"
    params: SessionNotification | None = None


class RequestPermissionJsonRpcRequest(JsonRpcRequest):
    method: str = "session/request_permission"
    params: RequestPermissionRequest | None = None


class RequestPermissionJsonRpcResponse(JsonRpcResponse):
    result: RequestPermissionResponse | None = None


class ReadTextFileJsonRpcRequest(JsonRpcRequest):
    method: str = "fs/read_text_file"
    params: ReadTextFileRequest | None = None


class ReadTextFileJsonRpcResponse(JsonRpcResponse):
    result: ReadTextFileResponse | None = None


class WriteTextFileJsonRpcRequest(JsonRpcRequest):
    method: str = "fs/write_text_file"
    params: WriteTextFileRequest | None = None


class WriteTextFileJsonRpcResponse(JsonRpcResponse):
    result: None = None
async def get_acp_agent_process(
    mock: bool = True, mock_env: dict[str, str] | None = None
) -> AsyncGenerator[asyncio.subprocess.Process]:
    """Spawn the ACP agent subprocess, yielding it with piped stdio.

    When ``mock`` is true the mock entrypoint is run with a dummy API key;
    otherwise the real ``vibe-acp`` command is used. ``mock_env`` entries are
    layered on top of the current environment. The process is terminated
    (killed after a 0.5s grace period) when the generator is closed.
    """
    current_env = os.environ.copy()
    cmd = ["uv", "run", MOCK_ENTRYPOINT_PATH if mock else "vibe-acp"]
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=TESTS_ROOT.parent,
        env={
            **current_env,
            **(mock_env or {}),
            **({"MISTRAL_API_KEY": "mock"} if mock else {}),
        },
    )
    try:
        yield process
    finally:
        # Cleanup
        if process.returncode is None:
            process.terminate()
            try:
                await asyncio.wait_for(process.wait(), timeout=0.5)
            except TimeoutError:
                process.kill()
                await process.wait()
async def send_json_rpc(
    process: asyncio.subprocess.Process, message: JsonRpcMessage
) -> None:
    """Serialize ``message`` and write it, newline-terminated, to the agent's stdin."""
    stdin = process.stdin
    if stdin is None:
        raise RuntimeError("Process stdin not available")
    payload = message.model_dump_json() + "\n"
    stdin.write(payload.encode())
    await stdin.drain()
async def read_response(
    process: asyncio.subprocess.Process, timeout: float = RESPONSE_TIMEOUT
) -> str | None:
    """Return the next JSON line from the agent's stdout.

    Non-JSON lines (log output) and blank lines are skipped. Returns ``None``
    on EOF or when ``timeout`` elapses.
    """
    stdout = process.stdout
    if stdout is None:
        raise RuntimeError("Process stdout not available")
    try:
        # Keep pulling lines until one parses as JSON.
        while True:
            raw = await asyncio.wait_for(stdout.readline(), timeout=timeout)
            if not raw:
                return None  # EOF reached
            candidate = raw.decode().strip()
            if not candidate:
                continue
            try:
                json.loads(candidate)
            except json.JSONDecodeError:
                # Not JSON — treat as log noise and keep reading.
                continue
            return candidate
    except TimeoutError:
        return None
async def read_response_for_id(
    process: asyncio.subprocess.Process,
    expected_id: int | str,
    timeout: float = RESPONSE_TIMEOUT,
) -> str | None:
    """Read messages until one with ``expected_id`` arrives or the budget expires.

    The ``timeout`` is a total budget shared across all reads, not per message.
    Messages with other ids are skipped (and logged). Returns ``None`` on
    timeout or EOF.
    """
    loop = asyncio.get_running_loop()
    end_time = loop.time() + timeout
    # Shrink the per-read timeout as the overall deadline approaches.
    while (remaining := end_time - loop.time()) > 0:
        response = await read_response(process, timeout=remaining)
        if response is None:
            return None
        response_json = json.loads(response)
        if response_json.get("id") == expected_id:
            return response
        print(
            f"Skipping response with id={response_json.get('id')}, expecting {expected_id}"
        )
    return None
async def read_multiple_responses(
    process: asyncio.subprocess.Process,
    max_count: int = 10,
    timeout_per_response: float = RESPONSE_TIMEOUT,
) -> list[str]:
    """Collect up to ``max_count`` JSON messages, stopping at the first timeout/EOF."""
    collected: list[str] = []
    while len(collected) < max_count:
        message = await read_response(process, timeout=timeout_per_response)
        if not message:
            break
        collected.append(message)
    return collected
def parse_conversation(message_texts: list[str]) -> list[JsonRpcMessage]:
    """Parse raw JSON lines into typed JSON-RPC message objects.

    Requests and notifications are classified by their ``method``. Responses
    carry no method, so each is matched to the earlier request with the same
    id to choose the right response class; unmatched responses fall back to
    the generic ``JsonRpcResponse``.

    Raises:
        ValueError: if a message cannot be mapped to any known class.
    """
    parsed_messages: list[JsonRpcMessage] = []
    for message_text in message_texts:
        message_json = json.loads(message_text)
        cls = None
        # JSON-RPC taxonomy: request = method + id, notification = method
        # without id, response = has a result.
        # NOTE(review): an error response (no "result" key) matches none of
        # the branches and raises ValueError below — confirm this is intended.
        has_method = message_json.get("method", None) is not None
        has_id = message_json.get("id", None) is not None
        has_result = message_json.get("result", None) is not None
        is_request = has_method and has_id
        is_notification = has_method and not has_id
        is_response = has_result
        if is_request:
            match message_json.get("method"):
                case "session/prompt":
                    cls = PromptJsonRpcRequest
                case "session/request_permission":
                    cls = RequestPermissionJsonRpcRequest
                case "fs/read_text_file":
                    cls = ReadTextFileJsonRpcRequest
                case "fs/write_text_file":
                    cls = WriteTextFileJsonRpcRequest
        elif is_notification:
            match message_json.get("method"):
                case "session/update":
                    cls = UpdateJsonRpcNotification
        elif is_response:
            # For responses, since we don't know the method, we need to find
            # the matching request.
            matching_request = next(
                (
                    m
                    for m in parsed_messages
                    if isinstance(m, JsonRpcRequest) and m.id == message_json.get("id")
                ),
                None,
            )
            if matching_request is None:
                # No matching request found in the conversation, it most probably was
                # not included in the conversation. We use a generic response class.
                cls = JsonRpcResponse
            else:
                match matching_request.method:
                    case "session/prompt":
                        cls = PromptJsonRpcResponse
                    case "session/request_permission":
                        cls = RequestPermissionJsonRpcResponse
                    case "fs/read_text_file":
                        cls = ReadTextFileJsonRpcResponse
                    case "fs/write_text_file":
                        cls = WriteTextFileJsonRpcResponse
        if cls is None:
            raise ValueError(f"No valid message class found for {message_json}")
        parsed_messages.append(cls.model_validate(message_json))
    return parsed_messages
async def initialize_session(acp_agent_process: asyncio.subprocess.Process) -> str:
    """Perform the initialize + session/new handshake and return the session id."""
    await send_json_rpc(
        acp_agent_process,
        InitializeJsonRpcRequest(id=1, params=InitializeRequest(protocolVersion=1)),
    )
    initialize_response = await read_response_for_id(
        acp_agent_process, expected_id=1, timeout=5.0
    )
    assert initialize_response is not None
    await send_json_rpc(
        acp_agent_process,
        NewSessionJsonRpcRequest(
            id=2, params=NewSessionRequest(cwd=str(PLAYGROUND_DIR), mcpServers=[])
        ),
    )
    session_response = await read_response_for_id(acp_agent_process, expected_id=2)
    assert session_response is not None
    session_response_json = json.loads(session_response)
    session_response_obj = NewSessionJsonRpcResponse.model_validate(
        session_response_json
    )
    assert session_response_obj.result is not None, "No result in response"
    return session_response_obj.result.sessionId
class TestInitialization:
    """Covers the ACP ``initialize`` handshake."""

    @pytest.mark.asyncio
    async def test_initialize_request_response(self) -> None:
        """The agent reports protocol version, capabilities, identity and auth methods."""
        mock_env = get_mocking_env()
        async for process in get_acp_agent_process(mock_env=mock_env):
            await send_json_rpc(
                process,
                InitializeJsonRpcRequest(
                    id=1, params=InitializeRequest(protocolVersion=1)
                ),
            )
            text_response = await read_response(process, timeout=10.0)
            assert text_response is not None, "No response to initialize"
            response_json = json.loads(text_response)
            response = InitializeJsonRpcResponse.model_validate(response_json)
            assert response.error is None, f"JSON-RPC error: {response.error}"
            assert response.result is not None, "No result in response"
            assert response.result.protocolVersion == 1
            assert response.result.agentCapabilities == AgentCapabilities(
                loadSession=False,
                promptCapabilities=PromptCapabilities(
                    audio=False, embeddedContext=True, image=False
                ),
                mcpCapabilities=McpCapabilities(http=False, sse=False),
            )
            assert response.result.agentInfo == Implementation(
                name="@mistralai/mistral-vibe", title="Mistral Vibe", version="0.1.0"
            )
            # The agent must advertise the terminal-based setup auth method.
            vibe_setup_method = next(
                (
                    method
                    for method in response.result.authMethods or []
                    if method.id == "vibe-setup"
                ),
                None,
            )
            assert vibe_setup_method is not None, "vibe-setup auth not found"
            assert vibe_setup_method.field_meta is not None
            assert "terminal-auth" in vibe_setup_method.field_meta.keys()
class TestSessionManagement:
    """Covers creation of multiple sessions in one agent process."""

    @pytest.mark.asyncio
    async def test_multiple_sessions_unique_ids(self) -> None:
        """Three session/new calls yield three distinct session ids."""
        mock_env = get_mocking_env(mock_chunks=[mock_llm_chunk() for _ in range(3)])
        async for process in get_acp_agent_process(mock_env=mock_env):
            await send_json_rpc(
                process,
                InitializeJsonRpcRequest(
                    id=1, params=InitializeRequest(protocolVersion=1)
                ),
            )
            await read_response_for_id(process, expected_id=1, timeout=5.0)
            session_ids = []
            for i in range(3):
                await send_json_rpc(
                    process,
                    NewSessionJsonRpcRequest(
                        id=i + 2,
                        params=NewSessionRequest(
                            cwd=str(PLAYGROUND_DIR), mcpServers=[]
                        ),
                    ),
                )
                text_response = await read_response_for_id(
                    process, expected_id=i + 2, timeout=RESPONSE_TIMEOUT
                )
                assert text_response is not None
                response_json = json.loads(text_response)
                response = NewSessionJsonRpcResponse.model_validate(response_json)
                assert response.error is None, f"JSON-RPC error: {response.error}"
                assert response.result is not None, "No result in response"
                session_ids.append(response.result.sessionId)
            assert len(set(session_ids)) == 3
class TestSessionUpdates:
    """Covers session/update notifications emitted during a prompt turn."""

    @pytest.mark.asyncio
    async def test_agent_message_chunk_structure(self) -> None:
        """A plain text reply arrives as an agent_message_chunk update."""
        mock_env = get_mocking_env([mock_llm_chunk(content="Hi") for _ in range(2)])
        async for process in get_acp_agent_process(mock_env=mock_env):
            # Check stderr for error details if process failed
            if process.returncode is not None and process.stderr:
                stderr_data = await process.stderr.read()
                if stderr_data:
                    # Log stderr for debugging test failures
                    pass  # Could add proper logging here if needed
            session_id = await initialize_session(process)
            await send_json_rpc(
                process,
                PromptJsonRpcRequest(
                    id=3,
                    params=PromptRequest(
                        sessionId=session_id,
                        prompt=[TextContentBlock(type="text", text="Just say hi")],
                    ),
                ),
            )
            text_response = await read_response(process)
            assert text_response is not None
            response = UpdateJsonRpcNotification.model_validate(
                json.loads(text_response)
            )
            assert response.params is not None
            assert response.params.update.sessionUpdate == "agent_message_chunk"
            assert response.params.update.content is not None
            assert response.params.update.content.type == "text"
            assert response.params.update.content.text is not None
            assert response.params.update.content.text == "Hi"

    @pytest.mark.asyncio
    async def test_tool_call_update_structure(self) -> None:
        """A mocked grep tool call is surfaced as a tool_call update with
        normalized raw input and a human-readable title."""
        mock_env = get_mocking_env([
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="grep", arguments='{"pattern": "auth"}'
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="bash",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The files containing the pattern 'auth' are ...",
                finish_reason="stop",
            ),
        ])
        async for process in get_acp_agent_process(mock_env=mock_env):
            session_id = await initialize_session(process)
            await send_json_rpc(
                process,
                PromptJsonRpcRequest(
                    id=3,
                    params=PromptRequest(
                        sessionId=session_id,
                        prompt=[
                            TextContentBlock(
                                type="text",
                                text="Show me files that are related to auth",
                            )
                        ],
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=10)
            assert len(text_responses) > 0
            responses = [
                UpdateJsonRpcNotification.model_validate(json.loads(r))
                for r in text_responses
            ]
            # Pick the first tool_call update out of the stream.
            tool_call = next(
                (
                    r
                    for r in responses
                    if isinstance(r, UpdateJsonRpcNotification)
                    and r.params is not None
                    and r.params.update.sessionUpdate == "tool_call"
                ),
                None,
            )
            assert tool_call is not None
            assert tool_call.params is not None
            assert tool_call.params.update is not None
            assert tool_call.params.update.sessionUpdate == "tool_call"
            assert tool_call.params.update.kind == "search"
            assert tool_call.params.update.title == "grep: 'auth'"
            assert (
                tool_call.params.update.rawInput
                == '{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}'
            )
async def start_session_with_request_permission(
    process: asyncio.subprocess.Process, prompt: str
) -> RequestPermissionJsonRpcRequest:
    """Drive a prompt turn up to the agent's permission request and return it.

    Asserts that the last message of the turn is a
    ``session/request_permission`` request offering 3 options.
    """
    session_id = await initialize_session(process)
    await send_json_rpc(
        process,
        PromptJsonRpcRequest(
            id=3,
            params=PromptRequest(
                sessionId=session_id,
                prompt=[TextContentBlock(type="text", text=prompt)],
            ),
        ),
    )
    text_responses = await read_multiple_responses(
        process, max_count=15, timeout_per_response=2.0
    )
    responses = parse_conversation(text_responses)
    last_response = responses[-1]
    assert isinstance(last_response, RequestPermissionJsonRpcRequest)
    assert last_response.params is not None
    assert len(last_response.params.options) == 3
    return last_response
@pytest.mark.skip(
    reason="Disabled until we have a way to properly mock the fs and acp interactions"
)
class TestToolCallStructure:
    """End-to-end checks on the tool-call permission/update lifecycle.

    Currently skipped: these tests exercise real fs/terminal interactions
    that are not yet mockable.
    """

    @pytest.mark.asyncio
    async def test_tool_call_request_permission_structure(self) -> None:
        """A write_file tool call triggers a well-formed permission request."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="write_file",
                            arguments='{"path":"test.txt","content":"hello, world!"'
                            ',"overwrite":true}',
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="write_file",
                finish_reason="stop",
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            session_id = await initialize_session(process)
            await send_json_rpc(
                process,
                PromptJsonRpcRequest(
                    id=3,
                    params=PromptRequest(
                        sessionId=session_id,
                        prompt=[
                            TextContentBlock(
                                type="text",
                                text="Create a new file named test.txt "
                                "with content 'hello, world!'",
                            )
                        ],
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=3)
            responses = parse_conversation(text_responses)
            # Look for tool call request permission updates
            permission_requests = [
                r for r in responses if isinstance(r, RequestPermissionJsonRpcRequest)
            ]
            assert len(permission_requests) > 0, (
                "No tool call permission requests found"
            )
            first_request = permission_requests[0]
            assert first_request.params is not None
            assert first_request.params.toolCall is not None
            assert first_request.params.toolCall.toolCallId is not None

    @pytest.mark.asyncio
    async def test_tool_call_update_approved_structure(self) -> None:
        """Allowing the permission request leads to a completed tool_call_update."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="write_file",
                            arguments='{"path":"test.txt","content":"hello, world!"'
                            ',"overwrite":true}',
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="write_file",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The file test.txt has been created", finish_reason="stop"
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            permission_request = await start_session_with_request_permission(
                process, "Create a file named test.txt"
            )
            assert permission_request.params is not None
            selected_option_id = ToolOption.ALLOW_ONCE
            await send_json_rpc(
                process,
                RequestPermissionJsonRpcResponse(
                    id=permission_request.id,
                    result=RequestPermissionResponse(
                        outcome=AllowedOutcome(
                            outcome="selected", optionId=selected_option_id
                        )
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=7)
            responses = parse_conversation(text_responses)
            # The approved tool call must end in a "completed" status update
            # referencing the same toolCallId as the permission request.
            approved_tool_call = next(
                (
                    r
                    for r in responses
                    if isinstance(r, UpdateJsonRpcNotification)
                    and r.method == "session/update"
                    and r.params is not None
                    and r.params.update is not None
                    and r.params.update.sessionUpdate == "tool_call_update"
                    and r.params.update.toolCallId
                    == (permission_request.params.toolCall.toolCallId)
                    and r.params.update.status == "completed"
                ),
                None,
            )
            assert approved_tool_call is not None

    @pytest.mark.asyncio
    async def test_tool_call_update_rejected_structure(self) -> None:
        """Rejecting the permission request leads to a failed tool_call_update."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="write_file",
                            arguments='{"path":"test.txt","content":"hello, world!"'
                            ',"overwrite":false}',
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="write_file",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The file test.txt has not been created, "
                "because you rejected the permission request",
                finish_reason="stop",
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            permission_request = await start_session_with_request_permission(
                process, "Create a file named test.txt"
            )
            assert permission_request.params is not None
            selected_option_id = ToolOption.REJECT_ONCE
            await send_json_rpc(
                process,
                RequestPermissionJsonRpcResponse(
                    id=permission_request.id,
                    result=RequestPermissionResponse(
                        outcome=AllowedOutcome(
                            outcome="selected", optionId=selected_option_id
                        )
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=5)
            responses = parse_conversation(text_responses)
            rejected_tool_call = next(
                (
                    r
                    for r in responses
                    if isinstance(r, UpdateJsonRpcNotification)
                    and r.method == "session/update"
                    and r.params is not None
                    and r.params.update.sessionUpdate == "tool_call_update"
                    and r.params.update.toolCallId
                    == (permission_request.params.toolCall.toolCallId)
                    and r.params.update.status == "failed"
                ),
                None,
            )
            assert rejected_tool_call is not None

    @pytest.mark.skip(reason="Long running tool call updates are not implemented yet")
    @pytest.mark.asyncio
    async def test_tool_call_in_progress_update_structure(self) -> None:
        """A long-running command should emit in_progress tool_call_updates."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="bash",
                            arguments='{"command":"sleep 3","timeout":null}',
                        ),
                        type="function",
                    )
                ],
                name="bash",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The command sleep 3 has been run", finish_reason="stop"
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            session_id = await initialize_session(process)
            await send_json_rpc(
                process,
                PromptJsonRpcRequest(
                    id=3,
                    params=PromptRequest(
                        sessionId=session_id,
                        prompt=[
                            TextContentBlock(
                                type="text", text="Run sleep 3 in the current directory"
                            )
                        ],
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=4)
            responses = parse_conversation(text_responses)
            # Look for tool call in progress updates
            in_progress_calls = [
                r
                for r in responses
                if isinstance(r, UpdateJsonRpcNotification)
                and r.params is not None
                and r.params.update.sessionUpdate == "tool_call_update"
                and r.params.update.status == "in_progress"
            ]
            assert len(in_progress_calls) > 0, (
                "No tool call in progress updates found for a long running command"
            )

    @pytest.mark.asyncio
    async def test_tool_call_result_update_failure_structure(self) -> None:
        """An approved tool call that fails at execution reports a failed update
        with raw output attached."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="write_file",
                            arguments='{"path":"/test.txt","content":"hello, world!"'
                            ',"overwrite":true}',
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="write_file",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The file /test.txt has not been created "
                "because it's outside the project directory",
                finish_reason="stop",
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            permission_request = await start_session_with_request_permission(
                process, "Create a file named /test.txt"
            )
            assert permission_request.params is not None
            selected_option_id = ToolOption.ALLOW_ONCE
            await send_json_rpc(
                process,
                RequestPermissionJsonRpcResponse(
                    id=permission_request.id,
                    result=RequestPermissionResponse(
                        outcome=AllowedOutcome(
                            outcome="selected", optionId=selected_option_id
                        )
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=7)
            responses = parse_conversation(text_responses)
            # Look for tool call result failure updates
            failure_result = next(
                (
                    r
                    for r in responses
                    if isinstance(r, UpdateJsonRpcNotification)
                    and r.params is not None
                    and r.params.update.sessionUpdate == "tool_call_update"
                    and r.params.update.status == "failed"
                    and r.params.update.rawOutput is not None
                    and r.params.update.toolCallId is not None
                ),
                None,
            )
            assert failure_result is not None
class TestCancellationStructure:
    """Covers cancelling a pending permission request mid-turn."""

    @pytest.mark.skip(
        reason="Proper cancellation is not implemented yet, we still need to return "
        "the right end_turn and be able to cancel at any point in time "
        "(and not only at tool call time)"
    )
    @pytest.mark.asyncio
    async def test_tool_call_update_cancelled_structure(self) -> None:
        """Cancelling the permission request fails the tool call and ends the
        prompt turn with stopReason 'cancelled'."""
        custom_results = [
            mock_llm_chunk(content="Hey"),
            mock_llm_chunk(
                tool_calls=[
                    ToolCall(
                        function=FunctionCall(
                            name="write_file",
                            arguments='{"path":"test.txt","content":"hello, world!"'
                            ',"overwrite":false}',
                        ),
                        type="function",
                        index=0,
                    )
                ],
                name="write_file",
                finish_reason="tool_calls",
            ),
            mock_llm_chunk(
                content="The file test.txt has not been created, "
                "because you cancelled the permission request",
                finish_reason="stop",
            ),
        ]
        mock_env = get_mocking_env(custom_results)
        async for process in get_acp_agent_process(mock_env=mock_env):
            permission_request = await start_session_with_request_permission(
                process, "Create a file named test.txt"
            )
            assert permission_request.params is not None
            await send_json_rpc(
                process,
                RequestPermissionJsonRpcResponse(
                    id=permission_request.id,
                    result=RequestPermissionResponse(
                        outcome=DeniedOutcome(outcome="cancelled")
                    ),
                ),
            )
            text_responses = await read_multiple_responses(process, max_count=5)
            responses = parse_conversation(text_responses)
            assert len(responses) == 2, (
                "There should be only 2 responses: "
                "the tool call update and the prompt end turn"
            )
            cancelled_tool_call = next(
                (
                    r
                    for r in responses
                    if isinstance(r, UpdateJsonRpcNotification)
                    and r.method == "session/update"
                    and r.params is not None
                    and r.params.update.sessionUpdate == "tool_call_update"
                    and r.params.update.toolCallId
                    == (permission_request.params.toolCall.toolCallId)
                    and r.params.update.status == "failed"
                ),
                None,
            )
            assert cancelled_tool_call is not None
            cancelled_prompt_response = next(
                (
                    r
                    for r in responses
                    if isinstance(r, PromptJsonRpcResponse)
                    and r.result is not None
                    and r.result.stopReason == "cancelled"
                ),
                None,
            )
            assert cancelled_prompt_response is not None

536
tests/acp/test_bash.py Normal file
View File

@@ -0,0 +1,536 @@
from __future__ import annotations
import asyncio
from acp.schema import TerminalOutputResponse, WaitForTerminalExitResponse
import pytest
from vibe.acp.tools.builtins.bash import AcpBashState, Bash
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.bash import BashArgs, BashResult, BashToolConfig
class MockTerminalHandle:
    """Stand-in for an ACP terminal handle with scripted exit code and output."""

    def __init__(
        self,
        terminal_id: str = "test_terminal_123",
        exit_code: int | None = 0,
        output: str = "test output",
        wait_delay: float = 0.01,
    ) -> None:
        self.id = terminal_id
        # Values returned by wait_for_exit / current_output.
        self._exit_code = exit_code
        self._output = output
        # Simulated runtime of the command before it "exits".
        self._wait_delay = wait_delay
        self._killed = False

    async def wait_for_exit(self) -> WaitForTerminalExitResponse:
        """Sleep for the configured delay, then report the scripted exit code."""
        await asyncio.sleep(self._wait_delay)
        return WaitForTerminalExitResponse(exitCode=self._exit_code)

    async def current_output(self) -> TerminalOutputResponse:
        """Return the scripted, never-truncated terminal output."""
        return TerminalOutputResponse(output=self._output, truncated=False)

    async def kill(self) -> None:
        """Record that the terminal was killed."""
        self._killed = True

    async def release(self) -> None:
        """No-op: releasing the mock terminal needs no cleanup."""
        pass
class MockConnection:
    """Stand-in for the ACP connection: records calls and serves a mock terminal."""

    def __init__(self, terminal_handle: MockTerminalHandle | None = None) -> None:
        self._terminal_handle = terminal_handle or MockTerminalHandle()
        # Call-tracking flags inspected by the tests.
        self._create_terminal_called = False
        self._session_update_called = False
        # When set, createTerminal raises instead of returning the handle.
        self._create_terminal_error: Exception | None = None
        self._last_create_request = None

    async def createTerminal(self, request) -> MockTerminalHandle:
        """Record the request and return (or fail with) the configured handle."""
        self._create_terminal_called = True
        self._last_create_request = request
        if self._create_terminal_error:
            raise self._create_terminal_error
        return self._terminal_handle

    async def sessionUpdate(self, notification) -> None:
        """Record that a session update notification was sent."""
        self._session_update_called = True
@pytest.fixture
def mock_connection() -> MockConnection:
    """Fresh MockConnection with a default successful terminal handle."""
    return MockConnection()
@pytest.fixture
def acp_bash_tool(mock_connection: MockConnection) -> Bash:
    """Bash tool wired to the mock connection with fixed session/tool-call ids."""
    config = BashToolConfig()
    # Use model_construct to bypass Pydantic validation for testing
    state = AcpBashState.model_construct(
        connection=mock_connection,  # type: ignore[arg-type]
        session_id="test_session_123",
        tool_call_id="test_tool_call_456",
    )
    return Bash(config=config, state=state)
class TestAcpBashBasic:
    """Pure-logic checks on the ACP Bash tool: naming, summaries, command parsing."""

    def test_get_name(self) -> None:
        assert Bash.get_name() == "bash"

    def test_get_summary_simple_command(self) -> None:
        args = BashArgs(command="ls")
        display = Bash.get_summary(args)
        assert display == "ls"

    def test_get_summary_with_timeout(self) -> None:
        """The timeout, when set, is appended to the summary."""
        args = BashArgs(command="ls", timeout=10)
        display = Bash.get_summary(args)
        assert display == "ls (timeout 10s)"

    def test_parse_command_simple(self) -> None:
        tool = Bash(config=BashToolConfig(), state=AcpBashState())
        env, command, args = tool._parse_command("ls")
        assert env == []
        assert command == "ls"
        assert args == []

    def test_parse_command_with_args(self) -> None:
        tool = Bash(config=BashToolConfig(), state=AcpBashState())
        env, command, args = tool._parse_command("ls -la src")
        assert env == []
        assert command == "ls"
        assert args == ["-la", "src"]

    def test_parse_command_with_env(self) -> None:
        """Leading NAME=value pairs are split off as environment variables."""
        tool = Bash(config=BashToolConfig(), state=AcpBashState())
        env, command, args = tool._parse_command("NODE_ENV=test DEBUG=1 npm test")
        assert len(env) == 2
        assert env[0].name == "NODE_ENV"
        assert env[0].value == "test"
        assert env[1].name == "DEBUG"
        assert env[1].value == "1"
        assert command == "npm"
        assert args == ["test"]

    def test_parse_command_with_env_value_contains_equals(self) -> None:
        """Only the first '=' separates name from value."""
        tool = Bash(config=BashToolConfig(), state=AcpBashState())
        env, command, args = tool._parse_command(
            "PATH=/usr/bin:/usr/local/bin echo hello"
        )
        assert len(env) == 1
        assert env[0].name == "PATH"
        assert env[0].value == "/usr/bin:/usr/local/bin"
        assert command == "echo"
        assert args == ["hello"]
class TestAcpBashExecution:
@pytest.mark.asyncio
async def test_run_success(
self, acp_bash_tool: Bash, mock_connection: MockConnection
) -> None:
from pathlib import Path
args = BashArgs(command="echo hello")
result = await acp_bash_tool.run(args)
assert isinstance(result, BashResult)
assert result.stdout == "test output"
assert result.stderr == ""
assert result.returncode == 0
assert mock_connection._create_terminal_called
# Verify CreateTerminalRequest was created correctly
request = mock_connection._last_create_request
assert request is not None
assert request.sessionId == "test_session_123"
assert request.command == "echo"
assert request.args == ["hello"]
assert request.cwd == str(Path.cwd()) # effective_workdir defaults to cwd
@pytest.mark.asyncio
async def test_run_creates_terminal_with_env_vars(
self, mock_connection: MockConnection
) -> None:
tool = Bash(
config=BashToolConfig(),
state=AcpBashState.model_construct(
connection=mock_connection, # type: ignore[arg-type]
session_id="test_session",
tool_call_id="test_call",
),
)
args = BashArgs(command="NODE_ENV=test npm run build")
await tool.run(args)
request = mock_connection._last_create_request
assert request is not None
assert len(request.env) == 1
assert request.env[0].name == "NODE_ENV"
assert request.env[0].value == "test"
assert request.command == "npm"
assert request.args == ["run", "build"]
@pytest.mark.asyncio
async def test_run_with_nonzero_exit_code(
self, mock_connection: MockConnection
) -> None:
custom_handle = MockTerminalHandle(
terminal_id="custom_terminal", exit_code=1, output="error: command failed"
)
mock_connection._terminal_handle = custom_handle
tool = Bash(
config=BashToolConfig(),
state=AcpBashState.model_construct(
connection=mock_connection, # type: ignore[arg-type]
session_id="test_session",
tool_call_id="test_call",
),
)
args = BashArgs(command="test_command")
with pytest.raises(ToolError) as exc_info:
await tool.run(args)
assert (
str(exc_info.value)
== "Command failed: 'test_command'\nReturn code: 1\nStdout: error: command failed"
)
@pytest.mark.asyncio
async def test_run_create_terminal_failure(
self, mock_connection: MockConnection
) -> None:
mock_connection._create_terminal_error = RuntimeError("Connection failed")
tool = Bash(
config=BashToolConfig(),
state=AcpBashState.model_construct(
connection=mock_connection, # type: ignore[arg-type]
session_id="test_session",
tool_call_id="test_call",
),
)
args = BashArgs(command="test")
with pytest.raises(ToolError) as exc_info:
await tool.run(args)
assert (
str(exc_info.value)
== "Failed to create terminal: RuntimeError('Connection failed')"
)
@pytest.mark.asyncio
async def test_run_without_connection(self) -> None:
    """Running without a connection in state raises a descriptive ToolError."""
    tool = Bash(
        config=BashToolConfig(),
        state=AcpBashState.model_construct(
            connection=None, session_id="test_session", tool_call_id="test_call"
        ),
    )
    args = BashArgs(command="test")
    with pytest.raises(ToolError) as exc_info:
        await tool.run(args)
    assert (
        str(exc_info.value)
        == "Connection not available in tool state. This tool can only be used within an ACP session."
    )
@pytest.mark.asyncio
async def test_run_without_session_id(self) -> None:
    """Running without a session id in state raises a descriptive ToolError."""
    mock_connection = MockConnection()
    tool = Bash(
        config=BashToolConfig(),
        state=AcpBashState.model_construct(
            connection=mock_connection,  # type: ignore[arg-type]
            session_id=None,
            tool_call_id="test_call",
        ),
    )
    args = BashArgs(command="test")
    with pytest.raises(ToolError) as exc_info:
        await tool.run(args)
    assert (
        str(exc_info.value)
        == "Session ID not available in tool state. This tool can only be used within an ACP session."
    )
@pytest.mark.asyncio
async def test_run_with_none_exit_code(
    self, mock_connection: MockConnection
) -> None:
    """A None exit code from the terminal is treated as success (returncode 0)."""
    custom_handle = MockTerminalHandle(
        terminal_id="none_exit_terminal", exit_code=None, output="output"
    )
    mock_connection._terminal_handle = custom_handle
    tool = Bash(
        config=BashToolConfig(),
        state=AcpBashState.model_construct(
            connection=mock_connection,  # type: ignore[arg-type]
            session_id="test_session",
            tool_call_id="test_call",
        ),
    )
    args = BashArgs(command="test_command")
    result = await tool.run(args)
    assert result.returncode == 0
    assert result.stdout == "output"
class TestAcpBashTimeout:
    """Timeout handling for the ACP Bash tool."""

    @pytest.mark.asyncio
    async def test_run_with_timeout_raises_error_and_kills(
        self, mock_connection: MockConnection
    ) -> None:
        """An args-level timeout overrides the config default, raises, and kills."""
        custom_handle = MockTerminalHandle(
            terminal_id="timeout_terminal",
            output="partial output",
            wait_delay=20,  # Longer than the 1 second timeout
        )
        mock_connection._terminal_handle = custom_handle
        # Use a config with different default timeout to verify args timeout overrides it
        tool = Bash(
            config=BashToolConfig(default_timeout=30),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="slow_command", timeout=1)
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert str(exc_info.value) == "Command timed out after 1s: 'slow_command'"
        # The runaway terminal must have been killed after the timeout.
        assert custom_handle._killed

    @pytest.mark.asyncio
    async def test_run_timeout_handles_kill_failure(
        self, mock_connection: MockConnection
    ) -> None:
        """A failing kill() does not mask the timeout ToolError."""
        custom_handle = MockTerminalHandle(
            terminal_id="kill_failure_terminal",
            wait_delay=20,  # Longer than the 1 second timeout
        )
        mock_connection._terminal_handle = custom_handle

        async def failing_kill() -> None:
            raise RuntimeError("Kill failed")

        custom_handle.kill = failing_kill
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="slow_command", timeout=1)
        # Should still raise timeout error even if kill fails
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert str(exc_info.value) == "Command timed out after 1s: 'slow_command'"
class TestAcpBashEmbedding:
    """Session-update embedding behaviour of the ACP Bash tool."""

    @pytest.mark.asyncio
    async def test_run_with_embedding(self, mock_connection: MockConnection) -> None:
        """A run with a tool_call_id sends a session update."""
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="test")
        await tool.run(args)
        assert mock_connection._session_update_called

    @pytest.mark.asyncio
    async def test_run_embedding_without_tool_call_id(
        self, mock_connection: MockConnection
    ) -> None:
        """No session update is sent when tool_call_id is None."""
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id=None,
            ),
        )
        args = BashArgs(command="test")
        await tool.run(args)
        # Embedding should be skipped when tool_call_id is None
        assert not mock_connection._session_update_called

    @pytest.mark.asyncio
    async def test_run_embedding_handles_exception(
        self, mock_connection: MockConnection
    ) -> None:
        """A failing sessionUpdate does not break the command result."""
        # Make sessionUpdate raise an exception
        async def failing_session_update(notification) -> None:
            raise RuntimeError("Session update failed")

        mock_connection.sessionUpdate = failing_session_update
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="test")
        # Should not raise, embedding failure is silently ignored
        result = await tool.run(args)
        assert result is not None
        assert result.stdout == "test output"
class TestAcpBashConfig:
    """Config-level defaults for the ACP Bash tool."""

    @pytest.mark.asyncio
    async def test_run_uses_config_default_timeout(
        self, mock_connection: MockConnection
    ) -> None:
        """With args.timeout=None the config's default_timeout applies."""
        custom_handle = MockTerminalHandle(
            terminal_id="config_timeout_terminal",
            wait_delay=0.01,  # Shorter than config timeout
        )
        mock_connection._terminal_handle = custom_handle
        tool = Bash(
            config=BashToolConfig(default_timeout=30),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="fast", timeout=None)
        result = await tool.run(args)
        # Should succeed with config timeout
        assert result.returncode == 0
class TestAcpBashCleanup:
    """Terminal handles must be released on success, timeout, and release failure."""

    @pytest.mark.asyncio
    async def test_run_releases_terminal_on_success(
        self, mock_connection: MockConnection
    ) -> None:
        """A successful run releases its terminal handle."""
        custom_handle = MockTerminalHandle(terminal_id="cleanup_terminal")
        mock_connection._terminal_handle = custom_handle
        release_called = False

        async def mock_release() -> None:
            nonlocal release_called
            release_called = True

        custom_handle.release = mock_release
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="test")
        await tool.run(args)
        assert release_called

    @pytest.mark.asyncio
    async def test_run_releases_terminal_on_timeout(
        self, mock_connection: MockConnection
    ) -> None:
        """A timed-out run raises ToolError but still releases its terminal."""
        # The handle will wait 2 seconds, but timeout is 1 second,
        # so asyncio.wait_for() will raise TimeoutError
        custom_handle = MockTerminalHandle(
            terminal_id="timeout_cleanup_terminal",
            wait_delay=2.0,  # Longer than the 1 second timeout
        )
        mock_connection._terminal_handle = custom_handle
        release_called = False

        async def mock_release() -> None:
            nonlocal release_called
            release_called = True

        custom_handle.release = mock_release
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="slow", timeout=1)
        # Assert the timeout actually raises (a bare try/except would have
        # passed silently even if no error was raised), and that the
        # terminal is still released afterwards.
        with pytest.raises(ToolError):
            await tool.run(args)
        assert release_called

    @pytest.mark.asyncio
    async def test_run_handles_release_failure(
        self, mock_connection: MockConnection
    ) -> None:
        """A failing release() is swallowed; the command result is still returned."""
        custom_handle = MockTerminalHandle(terminal_id="release_failure_terminal")

        async def failing_release() -> None:
            raise RuntimeError("Release failed")

        custom_handle.release = failing_release
        mock_connection._terminal_handle = custom_handle
        tool = Bash(
            config=BashToolConfig(),
            state=AcpBashState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = BashArgs(command="test")
        # Should not raise, release failure is silently ignored
        result = await tool.run(args)
        assert result is not None
        assert result.stdout == "test output"

184
tests/acp/test_content.py Normal file
View File

@@ -0,0 +1,184 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from acp import AgentSideConnection, NewSessionRequest, PromptRequest
from acp.schema import (
EmbeddedResourceContentBlock,
ResourceContentBlock,
TextContentBlock,
TextResourceContents,
)
import pytest
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_connection import FakeAgentSideConnection
from vibe.acp.acp_agent import VibeAcpAgent
from vibe.core.agent import Agent
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
@pytest.fixture
def backend() -> FakeBackend:
    """Fake LLM backend preloaded with a single assistant reply ("Hi")."""
    single_turn = LLMChunk(
        message=LLMMessage(role=Role.assistant, content="Hi"),
        finish_reason="end_turn",
        usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
    )
    return FakeBackend(results=[single_turn])
@pytest.fixture
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
    """Build a VibeAcpAgent whose inner agents use the fake backend."""

    # Subclass that forces every constructed agent to use the fake backend.
    class PatchedAgent(Agent):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs, backend=backend)

    # Patch is started and never explicitly stopped here.
    patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
    vibe_acp_agent: VibeAcpAgent | None = None

    def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
        nonlocal vibe_acp_agent
        vibe_acp_agent = VibeAcpAgent(connection)
        return vibe_acp_agent

    # Relies on FakeAgentSideConnection invoking _create_agent during
    # construction so the nonlocal is populated before the return below.
    FakeAgentSideConnection(_create_agent)
    return vibe_acp_agent  # pyright: ignore[reportReturnType]
class TestACPContent:
    """Conversion of ACP prompt content blocks into the user message text."""

    @pytest.mark.asyncio
    async def test_text_content(
        self, acp_agent: VibeAcpAgent, backend: FakeBackend
    ) -> None:
        """Plain text blocks are forwarded verbatim."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        prompt_request = PromptRequest(
            prompt=[TextContentBlock(type="text", text="Say hi")],
            sessionId=session_response.sessionId,
        )
        response = await acp_agent.prompt(params=prompt_request)
        assert response.stopReason == "end_turn"
        user_message = next(
            (msg for msg in backend._requests_messages[0] if msg.role == Role.user),
            None,
        )
        assert user_message is not None, "User message not found in backend requests"
        assert user_message.content == "Say hi"

    @pytest.mark.asyncio
    async def test_resource_content(
        self, acp_agent: VibeAcpAgent, backend: FakeBackend
    ) -> None:
        """Embedded resources are appended as path/content lines after the text."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        prompt_request = PromptRequest(
            prompt=[
                TextContentBlock(type="text", text="What does this file do?"),
                EmbeddedResourceContentBlock(
                    type="resource",
                    resource=TextResourceContents(
                        uri="file:///home/my_file.py",
                        text="def hello():\n print('Hello, world!')",
                        mimeType="text/x-python",
                    ),
                ),
            ],
            sessionId=session_response.sessionId,
        )
        response = await acp_agent.prompt(params=prompt_request)
        assert response.stopReason == "end_turn"
        user_message = next(
            (msg for msg in backend._requests_messages[0] if msg.role == Role.user),
            None,
        )
        assert user_message is not None, "User message not found in backend requests"
        expected_content = (
            "What does this file do?"
            + "\n\npath: file:///home/my_file.py"
            + "\ncontent: def hello():\n print('Hello, world!')"
        )
        assert user_message.content == expected_content

    @pytest.mark.asyncio
    async def test_resource_link_content(
        self, acp_agent: VibeAcpAgent, backend: FakeBackend
    ) -> None:
        """Resource links serialize all provided metadata fields, one per line."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        prompt_request = PromptRequest(
            prompt=[
                TextContentBlock(type="text", text="Analyze this resource"),
                ResourceContentBlock(
                    type="resource_link",
                    uri="file:///home/document.pdf",
                    name="document.pdf",
                    title="Important Document",
                    description="A PDF document containing project specifications",
                    mimeType="application/pdf",
                    size=1024,
                ),
            ],
            sessionId=session_response.sessionId,
        )
        response = await acp_agent.prompt(params=prompt_request)
        assert response.stopReason == "end_turn"
        user_message = next(
            (msg for msg in backend._requests_messages[0] if msg.role == Role.user),
            None,
        )
        assert user_message is not None, "User message not found in backend requests"
        expected_content = (
            "Analyze this resource"
            + "\n\nuri: file:///home/document.pdf"
            + "\nname: document.pdf"
            + "\ntitle: Important Document"
            + "\ndescription: A PDF document containing project specifications"
            + "\nmimeType: application/pdf"
            + "\nsize: 1024"
        )
        assert user_message.content == expected_content

    @pytest.mark.asyncio
    async def test_resource_link_minimal(
        self, acp_agent: VibeAcpAgent, backend: FakeBackend
    ) -> None:
        """A link with only uri and name serializes just those two fields."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        prompt_request = PromptRequest(
            prompt=[
                ResourceContentBlock(
                    type="resource_link",
                    uri="file:///home/minimal.txt",
                    name="minimal.txt",
                )
            ],
            sessionId=session_response.sessionId,
        )
        response = await acp_agent.prompt(params=prompt_request)
        assert response.stopReason == "end_turn"
        user_message = next(
            (msg for msg in backend._requests_messages[0] if msg.role == Role.user),
            None,
        )
        assert user_message is not None, "User message not found in backend requests"
        expected_content = "uri: file:///home/minimal.txt\nname: minimal.txt"
        assert user_message.content == expected_content

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from unittest.mock import patch
from uuid import uuid4
from acp import (
PROTOCOL_VERSION,
InitializeRequest,
NewSessionRequest,
PromptRequest,
RequestError,
)
from acp.schema import TextContentBlock
import pytest
from pytest import raises
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_connection import FakeAgentSideConnection
from vibe.acp.acp_agent import VibeAcpAgent
from vibe.core.agent import Agent
from vibe.core.config import ModelConfig, VibeConfig
from vibe.core.types import Role
@pytest.fixture
def backend() -> FakeBackend:
    """Provide a fresh FakeBackend with no preloaded results."""
    return FakeBackend()
@pytest.fixture
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
    """Build a VibeAcpAgent whose agents use the fake backend and a fixed config."""
    config = VibeConfig(
        active_model="devstral-latest",
        models=[
            ModelConfig(
                name="devstral-latest", provider="mistral", alias="devstral-latest"
            )
        ],
    )

    # Subclass that swaps in the fake backend and test config after init.
    class PatchedAgent(Agent):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            self.backend = backend
            self.config = config

    # Patch is started and never explicitly stopped here.
    patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
    vibe_acp_agent: VibeAcpAgent | None = None

    def _create_agent(connection: Any) -> VibeAcpAgent:
        nonlocal vibe_acp_agent
        vibe_acp_agent = VibeAcpAgent(connection)
        return vibe_acp_agent

    # Relies on FakeAgentSideConnection invoking _create_agent during
    # construction so the nonlocal is populated before the return below.
    FakeAgentSideConnection(_create_agent)
    return vibe_acp_agent  # pyright: ignore[reportReturnType]
class TestMultiSessionCore:
    """Session isolation and concurrent prompting in the ACP agent."""

    @pytest.mark.asyncio
    async def test_different_sessions_use_different_agents(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """Each newSession gets its own id and its own agent instance."""
        await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
        session1_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session1 = acp_agent.sessions[session1_response.sessionId]
        session2_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session2 = acp_agent.sessions[session2_response.sessionId]
        assert session1.id != session2.id
        # Each agent should be independent
        assert session1.agent is not session2.agent
        assert id(session1.agent) != id(session2.agent)

    @pytest.mark.asyncio
    async def test_error_on_nonexistent_session(self, acp_agent: VibeAcpAgent) -> None:
        """Prompting an unknown session id raises an invalid-params RequestError."""
        await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
        await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        fake_session_id = "fake-session-id-" + str(uuid4())
        with raises(RequestError) as exc_info:
            await acp_agent.prompt(
                PromptRequest(
                    sessionId=fake_session_id,
                    prompt=[TextContentBlock(type="text", text="Hello, world!")],
                )
            )
        assert isinstance(exc_info.value, RequestError)
        assert str(exc_info.value) == "Invalid params"

    @pytest.mark.asyncio
    async def test_simultaneous_message_processing(
        self, acp_agent: VibeAcpAgent, backend: FakeBackend
    ) -> None:
        """Two sessions prompted concurrently each keep their own message history."""
        await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
        session1_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session1 = acp_agent.sessions[session1_response.sessionId]
        session2_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session2 = acp_agent.sessions[session2_response.sessionId]
        # One canned chunk per concurrent prompt, consumed in order.
        backend._chunks = [
            mock_llm_chunk(content="Response 1", finish_reason="stop"),
            mock_llm_chunk(content="Response 2", finish_reason="stop"),
        ]

        async def run_session1():
            await acp_agent.prompt(
                PromptRequest(
                    sessionId=session1.id,
                    prompt=[TextContentBlock(type="text", text="Prompt for session 1")],
                )
            )

        async def run_session2():
            await acp_agent.prompt(
                PromptRequest(
                    sessionId=session2.id,
                    prompt=[TextContentBlock(type="text", text="Prompt for session 2")],
                )
            )

        await asyncio.gather(run_session1(), run_session2())
        user_message1 = next(
            (msg for msg in session1.agent.messages if msg.role == Role.user), None
        )
        assert user_message1 is not None
        assert user_message1.content == "Prompt for session 1"
        assistant_message1 = next(
            (msg for msg in session1.agent.messages if msg.role == Role.assistant), None
        )
        assert assistant_message1 is not None
        assert assistant_message1.content == "Response 1"
        user_message2 = next(
            (msg for msg in session2.agent.messages if msg.role == Role.user), None
        )
        assert user_message2 is not None
        assert user_message2.content == "Prompt for session 2"
        assistant_message2 = next(
            (msg for msg in session2.agent.messages if msg.role == Role.assistant), None
        )
        assert assistant_message2 is not None
        assert assistant_message2.content == "Response 2"

View File

@@ -0,0 +1,140 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from acp import AgentSideConnection, NewSessionRequest, SetSessionModelRequest
import pytest
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_connection import FakeAgentSideConnection
from vibe.acp.acp_agent import VibeAcpAgent
from vibe.acp.utils import VibeSessionMode
from vibe.core.agent import Agent
from vibe.core.config import ModelConfig, VibeConfig
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
@pytest.fixture
def backend() -> FakeBackend:
    """Fake LLM backend preloaded with a single assistant reply ("Hi")."""
    single_turn = LLMChunk(
        message=LLMMessage(role=Role.assistant, content="Hi"),
        finish_reason="end_turn",
        usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
    )
    return FakeBackend(results=[single_turn])
@pytest.fixture
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
    """Build a VibeAcpAgent with two configured models for model-switch tests."""
    config = VibeConfig(
        active_model="devstral-latest",
        models=[
            ModelConfig(
                name="devstral-latest", provider="mistral", alias="devstral-latest"
            ),
            ModelConfig(
                name="devstral-small", provider="mistral", alias="devstral-small"
            ),
        ],
    )

    # Subclass that forces the fake backend and the two-model config.
    class PatchedAgent(Agent):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **{**kwargs, "backend": backend})
            self.config = config

    # Patch is started and never explicitly stopped here.
    patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
    vibe_acp_agent: VibeAcpAgent | None = None

    def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
        nonlocal vibe_acp_agent
        vibe_acp_agent = VibeAcpAgent(connection)
        return vibe_acp_agent

    # Relies on FakeAgentSideConnection invoking _create_agent during
    # construction so the nonlocal is populated before the return below.
    FakeAgentSideConnection(_create_agent)
    return vibe_acp_agent  # pyright: ignore[reportReturnType]
class TestACPNewSession:
    """Shape of the newSession response: models, modes, and session wiring."""

    @pytest.mark.asyncio
    async def test_new_session_response_structure(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """newSession returns a session id plus the model and mode catalogs."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        assert session_response.sessionId is not None
        acp_session = next(
            (
                s
                for s in acp_agent.sessions.values()
                if s.id == session_response.sessionId
            ),
            None,
        )
        assert acp_session is not None
        # The session id is propagated to the agent and its interaction logger.
        assert (
            acp_session.agent.interaction_logger.session_id
            == session_response.sessionId
        )
        assert session_response.sessionId == acp_session.agent.session_id
        # Both configured models are advertised, active one first/current.
        assert session_response.models is not None
        assert session_response.models.currentModelId is not None
        assert session_response.models.availableModels is not None
        assert len(session_response.models.availableModels) == 2
        assert session_response.models.currentModelId == "devstral-latest"
        assert session_response.models.availableModels[0].modelId == "devstral-latest"
        assert session_response.models.availableModels[0].name == "devstral-latest"
        assert session_response.models.availableModels[1].modelId == "devstral-small"
        assert session_response.models.availableModels[1].name == "devstral-small"
        # Both session modes are advertised; approval-required is the default.
        assert session_response.modes is not None
        assert session_response.modes.currentModeId is not None
        assert session_response.modes.availableModes is not None
        assert len(session_response.modes.availableModes) == 2
        assert session_response.modes.currentModeId == VibeSessionMode.APPROVAL_REQUIRED
        assert (
            session_response.modes.availableModes[0].id
            == VibeSessionMode.APPROVAL_REQUIRED
        )
        assert session_response.modes.availableModes[0].name == "Approval Required"
        assert (
            session_response.modes.availableModes[1].id == VibeSessionMode.AUTO_APPROVE
        )
        assert session_response.modes.availableModes[1].name == "Auto Approve"

    @pytest.mark.skip(reason="TODO: Fix this test")
    @pytest.mark.asyncio
    async def test_new_session_preserves_model_after_set_model(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """A model set via setSessionModel should carry over to later sessions."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        assert session_response.models is not None
        assert session_response.models.currentModelId == "devstral-latest"
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
        )
        assert response is not None
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        assert session_response.models is not None
        assert session_response.models.currentModelId == "devstral-small"

240
tests/acp/test_read_file.py Normal file
View File

@@ -0,0 +1,240 @@
from __future__ import annotations
from pathlib import Path
from acp import ReadTextFileRequest, ReadTextFileResponse
import pytest
from vibe.acp.tools.builtins.read_file import AcpReadFileState, ReadFile
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.read_file import (
ReadFileArgs,
ReadFileResult,
ReadFileToolConfig,
)
class MockConnection:
    """Minimal ACP connection stand-in serving a fixed text file content."""

    def __init__(
        self,
        file_content: str = "line 1\nline 2\nline 3",
        read_error: Exception | None = None,
    ) -> None:
        # Content returned by readTextFile (optionally sliced by line/limit).
        self._file_content = file_content
        # When set, readTextFile raises this instead of returning content.
        self._read_error = read_error
        self._read_text_file_called = False
        self._session_update_called = False
        self._last_read_request: ReadTextFileRequest | None = None

    async def readTextFile(self, request: ReadTextFileRequest) -> ReadTextFileResponse:
        """Record the request and return the content, sliced if line/limit given."""
        self._read_text_file_called = True
        self._last_read_request = request
        if self._read_error:
            raise self._read_error
        content = self._file_content
        if request.line is not None or request.limit is not None:
            lines = content.splitlines(keepends=True)
            start_line = (request.line or 1) - 1  # Convert to 0-indexed
            end_line = (
                start_line + request.limit if request.limit is not None else len(lines)
            )
            lines = lines[start_line:end_line]
            content = "".join(lines)
        return ReadTextFileResponse(content=content)

    async def sessionUpdate(self, notification) -> None:
        """Record that a session update notification was sent."""
        self._session_update_called = True
@pytest.fixture
def mock_connection() -> MockConnection:
    """Fresh MockConnection per test."""
    return MockConnection()
@pytest.fixture
def acp_read_file_tool(mock_connection: MockConnection, tmp_path: Path) -> ReadFile:
    """ReadFile tool wired to the mock connection with tmp_path as workdir."""
    config = ReadFileToolConfig(workdir=tmp_path)
    state = AcpReadFileState.model_construct(
        connection=mock_connection,  # type: ignore[arg-type]
        session_id="test_session_123",
        tool_call_id="test_tool_call_456",
    )
    return ReadFile(config=config, state=state)
class TestAcpReadFileBasic:
    """Static metadata of the ACP ReadFile tool."""

    def test_get_name(self) -> None:
        assert ReadFile.get_name() == "read_file"
class TestAcpReadFileExecution:
    """run() behaviour of the ACP ReadFile tool against the mock connection."""

    @pytest.mark.asyncio
    async def test_run_success(
        self,
        acp_read_file_tool: ReadFile,
        mock_connection: MockConnection,
        tmp_path: Path,
    ) -> None:
        """A plain read returns the whole content and sends a session update."""
        test_file = tmp_path / "test_file.txt"
        test_file.touch()
        args = ReadFileArgs(path=str(test_file))
        result = await acp_read_file_tool.run(args)
        assert isinstance(result, ReadFileResult)
        assert result.path == str(test_file)
        assert result.content == "line 1\nline 2\nline 3"
        assert result.lines_read == 3
        assert mock_connection._read_text_file_called
        assert mock_connection._session_update_called
        # Verify ReadTextFileRequest was created correctly
        request = mock_connection._last_read_request
        assert request is not None
        assert request.sessionId == "test_session_123"
        assert request.path == str(test_file)
        assert request.line is None  # offset=0 means no line specified
        assert request.limit is None

    @pytest.mark.asyncio
    async def test_run_with_offset(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """offset=N skips the first N lines (request.line is 1-indexed)."""
        test_file = tmp_path / "test_file.txt"
        test_file.touch()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = ReadFileArgs(path=str(test_file), offset=1)
        result = await tool.run(args)
        assert result.lines_read == 2
        assert result.content == "line 2\nline 3"
        request = mock_connection._last_read_request
        assert request is not None
        assert request.line == 2  # offset=1 means line 2 (1-indexed)

    @pytest.mark.asyncio
    async def test_run_with_limit(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """limit=N reads at most N lines from the start."""
        test_file = tmp_path / "test_file.txt"
        test_file.touch()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = ReadFileArgs(path=str(test_file), limit=2)
        result = await tool.run(args)
        assert result.lines_read == 2
        assert result.content == "line 1\nline 2\n"
        request = mock_connection._last_read_request
        assert request is not None
        assert request.limit == 2

    @pytest.mark.asyncio
    async def test_run_with_offset_and_limit(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """offset and limit combine to read a window of lines."""
        test_file = tmp_path / "test_file.txt"
        test_file.touch()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = ReadFileArgs(path=str(test_file), offset=1, limit=1)
        result = await tool.run(args)
        assert result.lines_read == 1
        assert result.content == "line 2\n"
        request = mock_connection._last_read_request
        assert request is not None
        assert request.line == 2
        assert request.limit == 1

    @pytest.mark.asyncio
    async def test_run_read_error(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """A connection read failure is wrapped in a ToolError."""
        mock_connection._read_error = RuntimeError("File not found")
        test_file = tmp_path / "test.txt"
        test_file.touch()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        args = ReadFileArgs(path=str(test_file))
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert str(exc_info.value) == f"Error reading {test_file}: File not found"

    @pytest.mark.asyncio
    async def test_run_without_connection(self, tmp_path: Path) -> None:
        """Running without a connection in state raises a descriptive ToolError."""
        test_file = tmp_path / "test.txt"
        test_file.touch()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=None, session_id="test_session", tool_call_id="test_call"
            ),
        )
        args = ReadFileArgs(path=str(test_file))
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert (
            str(exc_info.value)
            == "Connection not available in tool state. This tool can only be used within an ACP session."
        )

    @pytest.mark.asyncio
    async def test_run_without_session_id(self, tmp_path: Path) -> None:
        """Running without a session id in state raises a descriptive ToolError."""
        test_file = tmp_path / "test.txt"
        test_file.touch()
        mock_connection = MockConnection()
        tool = ReadFile(
            config=ReadFileToolConfig(workdir=tmp_path),
            state=AcpReadFileState.model_construct(
                connection=mock_connection,  # type: ignore[arg-type]
                session_id=None,
                tool_call_id="test_call",
            ),
        )
        args = ReadFileArgs(path=str(test_file))
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert (
            str(exc_info.value)
            == "Session ID not available in tool state. This tool can only be used within an ACP session."
        )

View File

@@ -0,0 +1,339 @@
from __future__ import annotations
from pathlib import Path
from acp import ReadTextFileRequest, ReadTextFileResponse, WriteTextFileRequest
import pytest
from vibe.acp.tools.builtins.search_replace import AcpSearchReplaceState, SearchReplace
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.search_replace import (
SearchReplaceArgs,
SearchReplaceConfig,
SearchReplaceResult,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
class MockConnection:
    """ACP connection stand-in that records reads and writes of one file."""

    def __init__(
        self,
        file_content: str = "original line 1\noriginal line 2\noriginal line 3",
        read_error: Exception | None = None,
        write_error: Exception | None = None,
    ) -> None:
        # Content returned by readTextFile.
        self._file_content = file_content
        # When set, readTextFile raises this instead of returning content.
        self._read_error = read_error
        # When set, writeTextFile raises this after recording the request.
        self._write_error = write_error
        self._read_text_file_called = False
        self._write_text_file_called = False
        self._session_update_called = False
        self._last_read_request: ReadTextFileRequest | None = None
        self._last_write_request: WriteTextFileRequest | None = None
        # Every write request, in order (includes backup writes).
        self._write_calls: list[WriteTextFileRequest] = []

    async def readTextFile(self, request: ReadTextFileRequest) -> ReadTextFileResponse:
        """Record the request and return the fixed content (or raise)."""
        self._read_text_file_called = True
        self._last_read_request = request
        if self._read_error:
            raise self._read_error
        return ReadTextFileResponse(content=self._file_content)

    async def writeTextFile(self, request: WriteTextFileRequest) -> None:
        """Record the write request; the error (if any) is raised after recording."""
        self._write_text_file_called = True
        self._last_write_request = request
        self._write_calls.append(request)
        if self._write_error:
            raise self._write_error

    async def sessionUpdate(self, notification) -> None:
        """Record that a session update notification was sent."""
        self._session_update_called = True
@pytest.fixture
def mock_connection() -> MockConnection:
    """Fresh MockConnection per test."""
    return MockConnection()
@pytest.fixture
def acp_search_replace_tool(
    mock_connection: MockConnection, tmp_path: Path
) -> SearchReplace:
    """SearchReplace tool wired to the mock connection with tmp_path as workdir."""
    config = SearchReplaceConfig(workdir=tmp_path)
    state = AcpSearchReplaceState.model_construct(
        connection=mock_connection,  # type: ignore[arg-type]
        session_id="test_session_123",
        tool_call_id="test_tool_call_456",
    )
    return SearchReplace(config=config, state=state)
class TestAcpSearchReplaceBasic:
    """Static metadata of the ACP SearchReplace tool."""

    def test_get_name(self) -> None:
        assert SearchReplace.get_name() == "search_replace"
class TestAcpSearchReplaceExecution:
@pytest.mark.asyncio
async def test_run_success(
    self,
    acp_search_replace_tool: SearchReplace,
    mock_connection: MockConnection,
    tmp_path: Path,
) -> None:
    """One SEARCH/REPLACE block is applied and written via the connection."""
    test_file = tmp_path / "test_file.txt"
    test_file.write_text("original line 1\noriginal line 2\noriginal line 3")
    search_replace_content = (
        "<<<<<<< SEARCH\noriginal line 2\n=======\nmodified line 2\n>>>>>>> REPLACE"
    )
    args = SearchReplaceArgs(
        file_path=str(test_file), content=search_replace_content
    )
    result = await acp_search_replace_tool.run(args)
    assert isinstance(result, SearchReplaceResult)
    assert result.file == str(test_file)
    assert result.blocks_applied == 1
    assert mock_connection._read_text_file_called
    assert mock_connection._write_text_file_called
    assert mock_connection._session_update_called
    # Verify ReadTextFileRequest was created correctly
    read_request = mock_connection._last_read_request
    assert read_request is not None
    assert read_request.sessionId == "test_session_123"
    assert read_request.path == str(test_file)
    # Verify WriteTextFileRequest was created correctly
    write_request = mock_connection._last_write_request
    assert write_request is not None
    assert write_request.sessionId == "test_session_123"
    assert write_request.path == str(test_file)
    assert (
        write_request.content == "original line 1\nmodified line 2\noriginal line 3"
    )
@pytest.mark.asyncio
async def test_run_with_backup(
    self, mock_connection: MockConnection, tmp_path: Path
) -> None:
    """With create_backup=True, exactly one .bak write is issued."""
    config = SearchReplaceConfig(create_backup=True, workdir=tmp_path)
    tool = SearchReplace(
        config=config,
        state=AcpSearchReplaceState.model_construct(
            connection=mock_connection,  # type: ignore[arg-type]
            session_id="test_session",
            tool_call_id="test_call",
        ),
    )
    test_file = tmp_path / "test_file.txt"
    test_file.write_text("original line 1\noriginal line 2\noriginal line 3")
    search_replace_content = (
        "<<<<<<< SEARCH\noriginal line 1\n=======\nmodified line 1\n>>>>>>> REPLACE"
    )
    args = SearchReplaceArgs(
        file_path=str(test_file), content=search_replace_content
    )
    result = await tool.run(args)
    assert result.blocks_applied == 1
    # Should have written the main file and the backup
    assert len(mock_connection._write_calls) >= 1
    # Check if backup was written (it should be written to .bak file)
    assert sum(w.path.endswith(".bak") for w in mock_connection._write_calls) == 1
@pytest.mark.asyncio
async def test_run_read_error(
self, mock_connection: MockConnection, tmp_path: Path
) -> None:
mock_connection._read_error = RuntimeError("File not found")
tool = SearchReplace(
config=SearchReplaceConfig(workdir=tmp_path),
state=AcpSearchReplaceState.model_construct(
connection=mock_connection, # type: ignore[arg-type]
session_id="test_session",
tool_call_id="test_call",
),
)
test_file = tmp_path / "test.txt"
test_file.touch()
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
args = SearchReplaceArgs(
file_path=str(test_file), content=search_replace_content
)
with pytest.raises(ToolError) as exc_info:
await tool.run(args)
assert (
str(exc_info.value)
== f"Unexpected error reading {test_file}: File not found"
)
@pytest.mark.asyncio
async def test_run_write_error(
self, mock_connection: MockConnection, tmp_path: Path
) -> None:
mock_connection._write_error = RuntimeError("Permission denied")
test_file = tmp_path / "test.txt"
test_file.touch()
mock_connection._file_content = "old" # Update mock to return correct content
tool = SearchReplace(
config=SearchReplaceConfig(workdir=tmp_path),
state=AcpSearchReplaceState.model_construct(
connection=mock_connection, # type: ignore[arg-type]
session_id="test_session",
tool_call_id="test_call",
),
)
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
args = SearchReplaceArgs(
file_path=str(test_file), content=search_replace_content
)
with pytest.raises(ToolError) as exc_info:
await tool.run(args)
assert str(exc_info.value) == f"Error writing {test_file}: Permission denied"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"connection,session_id,expected_error",
[
(
None,
"test_session",
"Connection not available in tool state. This tool can only be used within an ACP session.",
),
(
MockConnection(),
None,
"Session ID not available in tool state. This tool can only be used within an ACP session.",
),
],
)
async def test_run_without_required_state(
self,
tmp_path: Path,
connection: MockConnection | None,
session_id: str | None,
expected_error: str,
) -> None:
test_file = tmp_path / "test.txt"
test_file.touch()
tool = SearchReplace(
config=SearchReplaceConfig(workdir=tmp_path),
state=AcpSearchReplaceState.model_construct(
connection=connection, # type: ignore[arg-type]
session_id=session_id,
tool_call_id="test_call",
),
)
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
args = SearchReplaceArgs(
file_path=str(test_file), content=search_replace_content
)
with pytest.raises(ToolError) as exc_info:
await tool.run(args)
assert str(exc_info.value) == expected_error
class TestAcpSearchReplaceSessionUpdates:
    """Shape of the session-update notifications produced by SearchReplace."""
    def test_tool_call_session_update(self) -> None:
        """A tool-call update carries a diff content item and a file location."""
        search_replace_content = (
            "<<<<<<< SEARCH\nold text\n=======\nnew text\n>>>>>>> REPLACE"
        )
        event = ToolCallEvent(
            tool_name="search_replace",
            tool_call_id="test_call_123",
            args=SearchReplaceArgs(
                file_path="/tmp/test.txt", content=search_replace_content
            ),
            tool_class=SearchReplace,
        )
        update = SearchReplace.tool_call_session_update(event)
        assert update is not None
        assert update.sessionUpdate == "tool_call"
        assert update.toolCallId == "test_call_123"
        assert update.kind == "edit"
        assert update.title is not None
        assert update.content is not None
        assert len(update.content) == 1
        assert update.content[0].type == "diff"
        assert update.content[0].path == "/tmp/test.txt"
        assert update.content[0].oldText == "old text"
        assert update.content[0].newText == "new text"
        assert update.locations is not None
        assert len(update.locations) == 1
        assert update.locations[0].path == "/tmp/test.txt"
    def test_tool_call_session_update_invalid_args(self) -> None:
        """Args of an unexpected type yield no session update."""
        class InvalidArgs:
            pass
        event = ToolCallEvent.model_construct(
            tool_name="search_replace",
            tool_call_id="test_call_123",
            args=InvalidArgs(), # type: ignore[arg-type]
            tool_class=SearchReplace,
        )
        update = SearchReplace.tool_call_session_update(event)
        assert update is None
    def test_tool_result_session_update(self) -> None:
        """A tool-result update reports completion with the applied diff."""
        search_replace_content = (
            "<<<<<<< SEARCH\nold text\n=======\nnew text\n>>>>>>> REPLACE"
        )
        result = SearchReplaceResult(
            file="/tmp/test.txt",
            blocks_applied=1,
            lines_changed=1,
            content=search_replace_content,
            warnings=[],
        )
        event = ToolResultEvent(
            tool_name="search_replace",
            tool_call_id="test_call_123",
            result=result,
            tool_class=SearchReplace,
        )
        update = SearchReplace.tool_result_session_update(event)
        assert update is not None
        assert update.sessionUpdate == "tool_call_update"
        assert update.toolCallId == "test_call_123"
        assert update.status == "completed"
        assert update.content is not None
        assert len(update.content) == 1
        assert update.content[0].type == "diff"
        assert update.content[0].path == "/tmp/test.txt"
        assert update.content[0].oldText == "old text"
        assert update.content[0].newText == "new text"
        assert update.locations is not None
        assert len(update.locations) == 1
        assert update.locations[0].path == "/tmp/test.txt"
    def test_tool_result_session_update_invalid_result(self) -> None:
        """A result of an unexpected type yields no session update."""
        class InvalidResult:
            pass
        event = ToolResultEvent.model_construct(
            tool_name="search_replace",
            tool_call_id="test_call_123",
            result=InvalidResult(), # type: ignore[arg-type]
            tool_class=SearchReplace,
        )
        update = SearchReplace.tool_result_session_update(event)
        assert update is None

165
tests/acp/test_set_mode.py Normal file
View File

@@ -0,0 +1,165 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from acp import AgentSideConnection, NewSessionRequest, SetSessionModeRequest
import pytest
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_connection import FakeAgentSideConnection
from vibe.acp.acp_agent import VibeAcpAgent
from vibe.acp.utils import VibeSessionMode
from vibe.core.agent import Agent
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
@pytest.fixture
def backend() -> FakeBackend:
    """Fake LLM backend that yields a single canned assistant chunk."""
    backend = FakeBackend(
        results=[
            LLMChunk(
                message=LLMMessage(role=Role.assistant, content="Hi"),
                finish_reason="end_turn",
                usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
            )
        ]
    )
    return backend
@pytest.fixture
def acp_agent(backend: FakeBackend, request: pytest.FixtureRequest) -> VibeAcpAgent:
    """Create a VibeAcpAgent whose inner agent is forced onto the fake backend.

    The original version called ``patch(...).start()`` without ever stopping it,
    so the mock leaked into every later test in the session. The patcher's
    ``stop`` is now registered as a pytest finalizer.
    """
    class PatchedAgent(Agent):
        def __init__(self, *args, **kwargs) -> None:
            # Inject the in-memory fake backend into every constructed agent.
            super().__init__(*args, **kwargs, backend=backend)
    patcher = patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent)
    patcher.start()
    request.addfinalizer(patcher.stop)
    vibe_acp_agent: VibeAcpAgent | None = None
    def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
        # FakeAgentSideConnection calls back here; capture the created agent.
        nonlocal vibe_acp_agent
        vibe_acp_agent = VibeAcpAgent(connection)
        return vibe_acp_agent
    FakeAgentSideConnection(_create_agent)
    return vibe_acp_agent  # pyright: ignore[reportReturnType]
class TestACPSetMode:
    """Behavior of VibeAcpAgent.setSessionMode for valid, invalid and no-op modes."""
    @pytest.mark.asyncio
    async def test_set_mode_to_approval_required(self, acp_agent: VibeAcpAgent) -> None:
        """Switching back to approval-required disables the agent's auto_approve."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        # Start from the auto-approve state so the transition is observable.
        acp_session.agent.auto_approve = True
        acp_session.mode_id = VibeSessionMode.AUTO_APPROVE
        response = await acp_agent.setSessionMode(
            SetSessionModeRequest(
                sessionId=session_id, modeId=VibeSessionMode.APPROVAL_REQUIRED
            )
        )
        assert response is not None
        assert acp_session.mode_id == VibeSessionMode.APPROVAL_REQUIRED
        assert acp_session.agent.auto_approve is False
    @pytest.mark.asyncio
    async def test_set_mode_to_AUTO_APPROVE(self, acp_agent: VibeAcpAgent) -> None:
        """Switching to auto-approve flips the agent's auto_approve flag on."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        assert acp_session.mode_id == VibeSessionMode.APPROVAL_REQUIRED
        assert acp_session.agent.auto_approve is False
        response = await acp_agent.setSessionMode(
            SetSessionModeRequest(
                sessionId=session_id, modeId=VibeSessionMode.AUTO_APPROVE
            )
        )
        assert response is not None
        assert acp_session.mode_id == VibeSessionMode.AUTO_APPROVE
        assert acp_session.agent.auto_approve is True
    @pytest.mark.asyncio
    async def test_set_mode_invalid_mode_returns_none(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """An unknown mode id is rejected and leaves the session untouched."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_mode_id = acp_session.mode_id
        initial_auto_approve = acp_session.agent.auto_approve
        response = await acp_agent.setSessionMode(
            SetSessionModeRequest(sessionId=session_id, modeId="invalid-mode")
        )
        assert response is None
        assert acp_session.mode_id == initial_mode_id
        assert acp_session.agent.auto_approve == initial_auto_approve
    @pytest.mark.asyncio
    async def test_set_mode_to_same_mode(self, acp_agent: VibeAcpAgent) -> None:
        """Re-setting the current mode succeeds and changes nothing."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_mode_id = VibeSessionMode.APPROVAL_REQUIRED
        assert acp_session.mode_id == initial_mode_id
        response = await acp_agent.setSessionMode(
            SetSessionModeRequest(sessionId=session_id, modeId=initial_mode_id)
        )
        assert response is not None
        assert acp_session.mode_id == initial_mode_id
        assert acp_session.agent.auto_approve is False
    @pytest.mark.asyncio
    async def test_set_mode_with_empty_string(self, acp_agent: VibeAcpAgent) -> None:
        """An empty mode id is rejected and leaves the session untouched."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_mode_id = acp_session.mode_id
        initial_auto_approve = acp_session.agent.auto_approve
        response = await acp_agent.setSessionMode(
            SetSessionModeRequest(sessionId=session_id, modeId="")
        )
        assert response is None
        assert acp_session.mode_id == initial_mode_id
        assert acp_session.agent.auto_approve == initial_auto_approve

308
tests/acp/test_set_model.py Normal file
View File

@@ -0,0 +1,308 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
from acp import AgentSideConnection, NewSessionRequest, SetSessionModelRequest
import pytest
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_connection import FakeAgentSideConnection
from vibe.acp.acp_agent import VibeAcpAgent
from vibe.core.agent import Agent
from vibe.core.config import ModelConfig, VibeConfig
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
@pytest.fixture
def backend() -> FakeBackend:
    """Fake LLM backend that yields a single canned assistant chunk."""
    backend = FakeBackend(
        results=[
            LLMChunk(
                message=LLMMessage(role=Role.assistant, content="Hi"),
                finish_reason="end_turn",
                usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
            )
        ]
    )
    return backend
@pytest.fixture
def acp_agent(backend: FakeBackend, request: pytest.FixtureRequest) -> VibeAcpAgent:
    """Create a VibeAcpAgent with a two-model config and the fake backend.

    The original version called ``patch(...).start()`` without ever stopping it,
    so the mock leaked into every later test in the session. The patcher's
    ``stop`` is now registered as a pytest finalizer.
    """
    config = VibeConfig(
        active_model="devstral-latest",
        models=[
            ModelConfig(
                name="devstral-latest",
                provider="mistral",
                alias="devstral-latest",
                input_price=0.4,
                output_price=2.0,
            ),
            ModelConfig(
                name="devstral-small",
                provider="mistral",
                alias="devstral-small",
                input_price=0.1,
                output_price=0.3,
            ),
        ],
    )
    class PatchedAgent(Agent):
        def __init__(self, *args, **kwargs) -> None:
            # Force the fake backend and the fixed config onto every agent,
            # then mirror the active model's pricing into the stats object.
            super().__init__(*args, **{**kwargs, "backend": backend})
            self.config = config
            try:
                active_model = config.get_active_model()
                self.stats.input_price_per_million = active_model.input_price
                self.stats.output_price_per_million = active_model.output_price
            except ValueError:
                pass
    patcher = patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent)
    patcher.start()
    request.addfinalizer(patcher.stop)
    vibe_acp_agent: VibeAcpAgent | None = None
    def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
        # FakeAgentSideConnection calls back here; capture the created agent.
        nonlocal vibe_acp_agent
        vibe_acp_agent = VibeAcpAgent(connection)
        return vibe_acp_agent
    FakeAgentSideConnection(_create_agent)
    return vibe_acp_agent  # pyright: ignore[reportReturnType]
class TestACPSetModel:
    """Behavior of VibeAcpAgent.setSessionModel: switching, persistence, stats."""
    @pytest.mark.asyncio
    async def test_set_model_success(self, acp_agent: VibeAcpAgent) -> None:
        """Switching to a known model updates the session's active model."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        assert acp_session.agent.config.active_model == "devstral-latest"
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
        )
        assert response is not None
        assert acp_session.agent.config.active_model == "devstral-small"
    @pytest.mark.asyncio
    async def test_set_model_invalid_model_returns_none(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """An unknown model id is rejected and the active model is kept."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_model = acp_session.agent.config.active_model
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="non-existent-model")
        )
        assert response is None
        assert acp_session.agent.config.active_model == initial_model
    @pytest.mark.asyncio
    async def test_set_model_to_same_model(self, acp_agent: VibeAcpAgent) -> None:
        """Re-selecting the current model succeeds and changes nothing."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        initial_model = "devstral-latest"
        assert acp_session is not None
        assert acp_session.agent.config.active_model == initial_model
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId=initial_model)
        )
        assert response is not None
        assert acp_session.agent.config.active_model == initial_model
    @pytest.mark.asyncio
    async def test_set_model_saves_to_config(self, acp_agent: VibeAcpAgent) -> None:
        """A successful switch persists the new active model via save_updates."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        with patch("vibe.acp.acp_agent.VibeConfig.save_updates") as mock_save:
            response = await acp_agent.setSessionModel(
                SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
            )
            assert response is not None
            mock_save.assert_called_once_with({"active_model": "devstral-small"})
    @pytest.mark.asyncio
    async def test_set_model_does_not_save_on_invalid_model(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """A rejected switch must not write anything to the config."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        with patch("vibe.acp.acp_agent.VibeConfig.save_updates") as mock_save:
            response = await acp_agent.setSessionModel(
                SetSessionModelRequest(
                    sessionId=session_id, modelId="non-existent-model"
                )
            )
            assert response is None
            mock_save.assert_not_called()
    @pytest.mark.asyncio
    async def test_set_model_with_empty_string(self, acp_agent: VibeAcpAgent) -> None:
        """An empty model id is rejected and the active model is kept."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_model = acp_session.agent.config.active_model
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="")
        )
        assert response is None
        assert acp_session.agent.config.active_model == initial_model
    @pytest.mark.asyncio
    async def test_set_model_updates_active_model(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """get_active_model reflects the newly selected model's alias."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        assert acp_session.agent.config.get_active_model().alias == "devstral-latest"
        await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
        )
        assert acp_session.agent.config.get_active_model().alias == "devstral-small"
    @pytest.mark.asyncio
    async def test_set_model_calls_reload_with_initial_messages(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """Switching models reloads the agent with the updated config."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        with patch.object(
            acp_session.agent, "reload_with_initial_messages"
        ) as mock_reload:
            response = await acp_agent.setSessionModel(
                SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
            )
            assert response is not None
            mock_reload.assert_called_once()
            call_args = mock_reload.call_args
            assert call_args.kwargs["config"] is not None
            assert call_args.kwargs["config"].active_model == "devstral-small"
    @pytest.mark.asyncio
    async def test_set_model_preserves_conversation_history(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """Existing system/user/assistant messages survive a model switch."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        user_msg = LLMMessage(role=Role.user, content="Hello")
        assistant_msg = LLMMessage(role=Role.assistant, content="Hi there!")
        acp_session.agent.messages.append(user_msg)
        acp_session.agent.messages.append(assistant_msg)
        # 3 = the initial system message plus the two appended above.
        assert len(acp_session.agent.messages) == 3
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
        )
        assert response is not None
        assert len(acp_session.agent.messages) == 3
        assert acp_session.agent.messages[0].role == Role.system
        assert acp_session.agent.messages[1].content == "Hello"
        assert acp_session.agent.messages[2].content == "Hi there!"
    @pytest.mark.asyncio
    async def test_set_model_resets_stats_with_new_model_pricing(
        self, acp_agent: VibeAcpAgent
    ) -> None:
        """Stats pricing fields follow the newly selected model's prices."""
        session_response = await acp_agent.newSession(
            NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
        )
        session_id = session_response.sessionId
        acp_session = next(
            (s for s in acp_agent.sessions.values() if s.id == session_id), None
        )
        assert acp_session is not None
        initial_model = acp_session.agent.config.get_active_model()
        initial_input_price = initial_model.input_price
        initial_output_price = initial_model.output_price
        initial_stats_input = acp_session.agent.stats.input_price_per_million
        initial_stats_output = acp_session.agent.stats.output_price_per_million
        assert acp_session.agent.stats.input_price_per_million == initial_input_price
        assert acp_session.agent.stats.output_price_per_million == initial_output_price
        response = await acp_agent.setSessionModel(
            SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
        )
        assert response is not None
        new_model = acp_session.agent.config.get_active_model()
        new_input_price = new_model.input_price
        new_output_price = new_model.output_price
        assert new_input_price != initial_input_price
        assert new_output_price != initial_output_price
        assert acp_session.agent.stats.input_price_per_million == new_input_price
        assert acp_session.agent.stats.output_price_per_million == new_output_price
        assert acp_session.agent.stats.input_price_per_million != initial_stats_input
        assert acp_session.agent.stats.output_price_per_million != initial_stats_output

View File

@@ -0,0 +1,269 @@
from __future__ import annotations
from pathlib import Path
from acp import WriteTextFileRequest
import pytest
from vibe.acp.tools.builtins.write_file import AcpWriteFileState, WriteFile
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.write_file import (
WriteFileArgs,
WriteFileConfig,
WriteFileResult,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
class MockConnection:
    """Test double for an ACP connection: records calls, optionally fails writes."""
    def __init__(
        self, write_error: Exception | None = None, file_exists: bool = False
    ) -> None:
        # Configurable failure plus bookkeeping inspected by the tests.
        self._write_error = write_error
        self._file_exists = file_exists
        self._write_text_file_called = False
        self._session_update_called = False
        self._last_write_request: WriteTextFileRequest | None = None
    async def writeTextFile(self, request: WriteTextFileRequest) -> None:
        """Remember the request, then raise the configured error if one is set."""
        self._write_text_file_called = True
        self._last_write_request = request
        error = self._write_error
        if error:
            raise error
    async def sessionUpdate(self, notification) -> None:
        """Flag that a session update was delivered; the payload is ignored."""
        self._session_update_called = True
@pytest.fixture
def mock_connection() -> MockConnection:
    """Provide a fresh MockConnection for each test."""
    connection = MockConnection()
    return connection
@pytest.fixture
def acp_write_file_tool(mock_connection: MockConnection, tmp_path: Path) -> WriteFile:
    """Build a WriteFile tool wired to the mock connection and a tmp workdir."""
    config = WriteFileConfig(workdir=tmp_path)
    # model_construct bypasses validation so the mock can stand in for a real connection.
    state = AcpWriteFileState.model_construct(
        connection=mock_connection, # type: ignore[arg-type]
        session_id="test_session_123",
        tool_call_id="test_tool_call_456",
    )
    return WriteFile(config=config, state=state)
class TestAcpWriteFileBasic:
    """Static metadata checks for the ACP WriteFile tool."""
    def test_get_name(self) -> None:
        """The tool registers under the name "write_file"."""
        assert WriteFile.get_name() == "write_file"
class TestAcpWriteFileExecution:
    """End-to-end runs of WriteFile.run against the mocked ACP connection."""
    @pytest.mark.asyncio
    async def test_run_success_new_file(
        self,
        acp_write_file_tool: WriteFile,
        mock_connection: MockConnection,
        tmp_path: Path,
    ) -> None:
        """Writing a new file issues a write request and reports file_existed=False."""
        test_file = tmp_path / "test_file.txt"
        args = WriteFileArgs(path=str(test_file), content="Hello, world!")
        result = await acp_write_file_tool.run(args)
        assert isinstance(result, WriteFileResult)
        assert result.path == str(test_file)
        assert result.content == "Hello, world!"
        assert result.bytes_written == len(b"Hello, world!")
        assert result.file_existed is False
        assert mock_connection._write_text_file_called
        assert mock_connection._session_update_called
        # Verify WriteTextFileRequest was created correctly
        request = mock_connection._last_write_request
        assert request is not None
        assert request.sessionId == "test_session_123"
        assert request.path == str(test_file)
        assert request.content == "Hello, world!"
    @pytest.mark.asyncio
    async def test_run_success_overwrite(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """Overwriting an existing file succeeds and reports file_existed=True."""
        tool = WriteFile(
            config=WriteFileConfig(workdir=tmp_path),
            state=AcpWriteFileState.model_construct(
                connection=mock_connection, # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        test_file = tmp_path / "existing_file.txt"
        test_file.touch()
        # Simulate existing file by checking in the core tool logic
        # The ACP tool doesn't check existence, it's handled by the core tool
        args = WriteFileArgs(path=str(test_file), content="New content", overwrite=True)
        result = await tool.run(args)
        assert isinstance(result, WriteFileResult)
        assert result.path == str(test_file)
        assert result.content == "New content"
        assert result.bytes_written == len(b"New content")
        assert result.file_existed is True
        assert mock_connection._write_text_file_called
        assert mock_connection._session_update_called
        # Verify WriteTextFileRequest was created correctly
        request = mock_connection._last_write_request
        assert request is not None
        assert request.sessionId == "test_session"
        assert request.path == str(test_file)
        assert request.content == "New content"
    @pytest.mark.asyncio
    async def test_run_write_error(
        self, mock_connection: MockConnection, tmp_path: Path
    ) -> None:
        """A failing write surfaces as a ToolError with the underlying message."""
        mock_connection._write_error = RuntimeError("Permission denied")
        tool = WriteFile(
            config=WriteFileConfig(workdir=tmp_path),
            state=AcpWriteFileState.model_construct(
                connection=mock_connection, # type: ignore[arg-type]
                session_id="test_session",
                tool_call_id="test_call",
            ),
        )
        test_file = tmp_path / "test.txt"
        args = WriteFileArgs(path=str(test_file), content="test")
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert str(exc_info.value) == f"Error writing {test_file}: Permission denied"
    @pytest.mark.asyncio
    async def test_run_without_connection(self, tmp_path: Path) -> None:
        """A missing connection raises a descriptive ToolError."""
        tool = WriteFile(
            config=WriteFileConfig(workdir=tmp_path),
            state=AcpWriteFileState.model_construct(
                connection=None, session_id="test_session", tool_call_id="test_call"
            ),
        )
        args = WriteFileArgs(path=str(tmp_path / "test.txt"), content="test")
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert (
            str(exc_info.value)
            == "Connection not available in tool state. This tool can only be used within an ACP session."
        )
    @pytest.mark.asyncio
    async def test_run_without_session_id(self, tmp_path: Path) -> None:
        """A missing session id raises a descriptive ToolError."""
        mock_connection = MockConnection()
        tool = WriteFile(
            config=WriteFileConfig(workdir=tmp_path),
            state=AcpWriteFileState.model_construct(
                connection=mock_connection, # type: ignore[arg-type]
                session_id=None,
                tool_call_id="test_call",
            ),
        )
        args = WriteFileArgs(path=str(tmp_path / "test.txt"), content="test")
        with pytest.raises(ToolError) as exc_info:
            await tool.run(args)
        assert (
            str(exc_info.value)
            == "Session ID not available in tool state. This tool can only be used within an ACP session."
        )
class TestAcpWriteFileSessionUpdates:
    """Shape of the session-update notifications produced by WriteFile."""
    def test_tool_call_session_update(self) -> None:
        """A tool-call update carries a diff (no oldText for a new file)."""
        event = ToolCallEvent(
            tool_name="write_file",
            tool_call_id="test_call_123",
            args=WriteFileArgs(path="/tmp/test.txt", content="Hello"),
            tool_class=WriteFile,
        )
        update = WriteFile.tool_call_session_update(event)
        assert update is not None
        assert update.sessionUpdate == "tool_call"
        assert update.toolCallId == "test_call_123"
        assert update.kind == "edit"
        assert update.title is not None
        assert update.content is not None
        assert len(update.content) == 1
        assert update.content[0].type == "diff"
        assert update.content[0].path == "/tmp/test.txt"
        assert update.content[0].oldText is None
        assert update.content[0].newText == "Hello"
        assert update.locations is not None
        assert len(update.locations) == 1
        assert update.locations[0].path == "/tmp/test.txt"
    def test_tool_call_session_update_invalid_args(self) -> None:
        """Args of an unexpected type yield no session update."""
        from vibe.core.types import FunctionCall, ToolCall
        class InvalidArgs:
            pass
        event = ToolCallEvent.model_construct(
            tool_name="write_file",
            tool_call_id="test_call_123",
            args=InvalidArgs(), # type: ignore[arg-type]
            tool_class=WriteFile,
            llm_tool_call=ToolCall(
                function=FunctionCall(name="write_file", arguments="{}"),
                type="function",
                index=0,
            ),
        )
        update = WriteFile.tool_call_session_update(event)
        assert update is None
    def test_tool_result_session_update(self) -> None:
        """A tool-result update reports completion with the written diff."""
        result = WriteFileResult(
            path="/tmp/test.txt", content="Hello", bytes_written=5, file_existed=False
        )
        event = ToolResultEvent(
            tool_name="write_file",
            tool_call_id="test_call_123",
            result=result,
            tool_class=WriteFile,
        )
        update = WriteFile.tool_result_session_update(event)
        assert update is not None
        assert update.sessionUpdate == "tool_call_update"
        assert update.toolCallId == "test_call_123"
        assert update.status == "completed"
        assert update.content is not None
        assert len(update.content) == 1
        assert update.content[0].type == "diff"
        assert update.content[0].path == "/tmp/test.txt"
        assert update.content[0].oldText is None
        assert update.content[0].newText == "Hello"
        assert update.locations is not None
        assert len(update.locations) == 1
        assert update.locations[0].path == "/tmp/test.txt"
    def test_tool_result_session_update_invalid_result(self) -> None:
        """A result of an unexpected type yields no session update."""
        class InvalidResult:
            pass
        event = ToolResultEvent.model_construct(
            tool_name="write_file",
            tool_call_id="test_call_123",
            result=InvalidResult(), # type: ignore[arg-type]
            tool_class=WriteFile,
        )
        update = WriteFile.tool_result_session_update(event)
        assert update is None

View File

@@ -0,0 +1,231 @@
from __future__ import annotations
from collections.abc import Callable, Generator
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import time
import pytest
from vibe.core.autocompletion.file_indexer import FileIndexer
# This suite runs against the real filesystem and watcher. A faked store/watcher
# split would be faster to unit-test, but given time constraints and the low churn
# expected for this feature, integration coverage was chosen as a trade-off.
@pytest.fixture
def file_indexer() -> Generator[FileIndexer]:
    """Yield a FileIndexer and shut its watcher down after the test."""
    indexer = FileIndexer()
    yield indexer
    indexer.shutdown()
def _wait_for(condition: Callable[[], bool], timeout=3.0) -> bool:
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if condition():
return True
time.sleep(0.05)
return False
def test_updates_index_on_file_creation(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """A file created after the initial index appears in the index."""
    monkeypatch.chdir(tmp_path)
    # Prime the index (and its watcher) before creating the file.
    file_indexer.get_index(Path("."))
    target = tmp_path / "new_file.py"
    target.write_text("", encoding="utf-8")
    assert _wait_for(
        lambda: any(
            entry.rel == target.name for entry in file_indexer.get_index(Path("."))
        )
    )
def test_updates_index_on_file_deletion(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """A file deleted after the initial index disappears from the index."""
    monkeypatch.chdir(tmp_path)
    target = tmp_path / "new_file.py"
    target.write_text("", encoding="utf-8")
    # Prime the index (and its watcher) before deleting the file.
    file_indexer.get_index(Path("."))
    target.unlink()
    assert _wait_for(
        lambda: all(
            entry.rel != target.name for entry in file_indexer.get_index(Path("."))
        )
    )
def test_updates_index_on_file_rename(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """Renaming a file drops the old name and adds the new one to the index."""
    monkeypatch.chdir(tmp_path)
    old_file = tmp_path / "old_name.py"
    old_file.write_text("", encoding="utf-8")
    # Prime the index (and its watcher) before the rename.
    file_indexer.get_index(Path("."))
    new_file = tmp_path / "new_name.py"
    old_file.rename(new_file)
    assert _wait_for(
        lambda: all(
            entry.rel != old_file.name for entry in file_indexer.get_index(Path("."))
        )
        and any(
            entry.rel == new_file.name for entry in file_indexer.get_index(Path("."))
        )
    )
def test_updates_index_on_folder_rename(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """Renaming a folder remaps every contained file path in the index."""
    monkeypatch.chdir(tmp_path)
    old_folder = tmp_path / "old_folder"
    old_folder.mkdir()
    number_of_files = 5
    file_names = [f"file{i}.py" for i in range(1, number_of_files + 1)]
    old_file_paths = [old_folder / name for name in file_names]
    for file_path in old_file_paths:
        file_path.write_text("", encoding="utf-8")
    # Prime the index (and its watcher) before the rename.
    file_indexer.get_index(Path("."))
    new_folder = tmp_path / "new_folder"
    old_folder.rename(new_folder)
    def _index_reflects_rename() -> bool:
        # Snapshot the index once so both checks see the same state
        # (replaces the previous `(entries := ..., cond)[1]` walrus hack).
        entries = file_indexer.get_index(Path("."))
        no_old_entries = all(
            not entry.rel.startswith("old_folder/") for entry in entries
        )
        all_new_entries = all(
            any(entry.rel == f"new_folder/{name}" for entry in entries)
            for name in file_names
        )
        return no_old_entries and all_new_entries
    assert _wait_for(_index_reflects_rename)
def test_updates_index_incrementally_by_default(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """A single file creation bumps incremental_updates without a full rebuild."""
    monkeypatch.chdir(tmp_path)
    # Prime the index, then snapshot the stats counters before the change.
    file_indexer.get_index(Path("."))
    rebuilds_before = file_indexer.stats.rebuilds
    incremental_before = file_indexer.stats.incremental_updates
    target = tmp_path / "stats_file.py"
    target.write_text("", encoding="utf-8")
    assert _wait_for(
        lambda: any(
            entry.rel == target.name for entry in file_indexer.get_index(Path("."))
        )
    )
    assert file_indexer.stats.rebuilds == rebuilds_before
    assert file_indexer.stats.incremental_updates >= incremental_before + 1
def test_rebuilds_index_when_mass_change_threshold_is_exceeded(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A burst of file creations above the threshold triggers a full rebuild."""
    mass_change_threshold = 5
    # in an ideal world, we would use "threshold + 1", but in reality, we need to test with a
    # number of files important enough to MAKE SURE that a batch of >= threshold events will be
    # detected by the watcher
    number_of_files = mass_change_threshold * 3
    monkeypatch.chdir(tmp_path)
    indexer = FileIndexer(mass_change_threshold=mass_change_threshold)
    try:
        indexer.get_index(Path("."))
        rebuilds_before = indexer.stats.rebuilds
        # Use a context manager so the pool's worker threads are joined before
        # we start polling (the original never shut the executor down, leaking
        # its threads). Submission in Executor.map is eager, and __exit__
        # waits for every pending write to finish.
        with ThreadPoolExecutor(max_workers=number_of_files) as executor:
            executor.map(
                lambda i: (tmp_path / f"bulk{i}.py").write_text("", encoding="utf-8"),
                range(number_of_files),
            )
        assert _wait_for(lambda: len(indexer.get_index(Path("."))) == number_of_files)
        # we do not assert that "incremental_updates" did not change,
        # as the watcher potentially reported some batches of events that were
        # smaller than the threshold
        assert indexer.stats.rebuilds >= rebuilds_before + 1
    finally:
        indexer.shutdown()
def test_switching_between_roots_restarts_index(
    tmp_path: Path,
    tmp_path_factory: pytest.TempPathFactory,
    monkeypatch: pytest.MonkeyPatch,
    file_indexer: FileIndexer,
) -> None:
    """Changing the working directory re-targets the index to the new root."""
    first_root = tmp_path
    second_root = tmp_path_factory.mktemp("second-root")
    (first_root / "first.py").write_text("", encoding="utf-8")
    (second_root / "second.py").write_text("", encoding="utf-8")

    def indexed(name: str) -> bool:
        return any(
            entry.rel == name for entry in file_indexer.get_index(Path("."))
        )

    monkeypatch.chdir(first_root)
    assert _wait_for(lambda: indexed("first.py"))
    monkeypatch.chdir(second_root)
    # Entries from the previous root must disappear before the new ones show up.
    assert _wait_for(lambda: not indexed("first.py"))
    assert _wait_for(lambda: indexed("second.py"))
def test_watcher_failure_does_not_break_existing_index(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
) -> None:
    """An exception while applying watcher changes must leave the previously
    built index readable and unchanged."""
    monkeypatch.chdir(tmp_path)
    seed = tmp_path / "seed.py"
    seed.write_text("", encoding="utf-8")
    file_indexer.get_index(Path("."))
    def boom(*_: object, **__: object) -> None:
        raise RuntimeError("boom")
    # Force every incremental store update to fail.
    monkeypatch.setattr(file_indexer._store, "apply_changes", boom)
    (tmp_path / "new_file.py").write_text("", encoding="utf-8")
    assert _wait_for(
        lambda: (
            entries := file_indexer.get_index(Path(".")),
            # new file was not added: watcher failed
            all(entry.rel != "new_file.py" for entry in entries)
            # but the existing index is still intact
            and all(entry.rel == "seed.py" for entry in entries),
        )[1]
    )
def test_shutdown_cleans_up_resources(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """After shutdown the indexer returns an empty index, not stale data."""
    monkeypatch.chdir(tmp_path)
    (tmp_path / "test.txt").write_text("", encoding="utf-8")
    indexer = FileIndexer()
    indexer.get_index(Path("."))
    indexer.shutdown()
    assert indexer.get_index(Path(".")) == []

View File

@@ -0,0 +1,96 @@
from __future__ import annotations
from vibe.core.autocompletion.fuzzy import fuzzy_match
def test_empty_pattern_matches_anything() -> None:
    """An empty pattern trivially matches with a zero score and no indices."""
    match = fuzzy_match("", "any_text")
    assert match.matched is True
    assert match.score == 0.0
    assert match.matched_indices == ()
def test_matches_exact_prefix() -> None:
    """A leading substring maps onto the first positions of the entry."""
    match = fuzzy_match("src/", "src/main.py")
    assert match.matched_indices == (0, 1, 2, 3)
def test_no_match_when_characters_are_out_of_order() -> None:
    """Pattern characters must appear in order within the entry."""
    match = fuzzy_match("ms", "src/main.py")
    assert match.matched is False
def test_treats_consecutive_characters_as_subsequence() -> None:
    """A run of consecutive characters matches mid-string."""
    match = fuzzy_match("main", "src/main.py")
    assert match.matched_indices == (4, 5, 6, 7)
def test_ignores_case() -> None:
    """Matching is case-insensitive."""
    match = fuzzy_match("SRC", "src/main.py")
    assert match.matched_indices == (0, 1, 2)
def test_treats_scattered_characters_as_subsequence() -> None:
    """Non-adjacent characters still match as an ordered subsequence."""
    match = fuzzy_match("sm", "src/main.py")
    assert match.matched_indices == (0, 4)
def test_treats_path_separator_as_word_boundary() -> None:
    """A character right after '/' is preferred over other occurrences."""
    match = fuzzy_match("m", "src/main.py")
    assert match.matched_indices == (4,)
def test_prefers_word_boundary_matching_over_subsequence() -> None:
    """A match starting at a word boundary outscores a scattered one."""
    at_boundary = fuzzy_match("ma", "src/main.py")
    scattered = fuzzy_match("ma", "src/important.py")
    assert at_boundary.score > scattered.score
def test_scores_exact_prefix_match_higher_than_consecutive_and_subsequence() -> None:
    """Prefix matches rank above both consecutive and scattered matches."""
    prefix = fuzzy_match("src", "src/main.py")
    consecutive = fuzzy_match("main", "src/main.py")
    scattered = fuzzy_match("sm", "src/main.py")
    assert prefix.matched_indices == (0, 1, 2)
    assert prefix.score > consecutive.score
    assert prefix.score > scattered.score
def test_finds_no_match_when_pattern_is_longer_than_entry() -> None:
    """A pattern longer than the entry can never match."""
    match = fuzzy_match("very_long_pattern", "short")
    assert match.matched is False
def test_prefers_consecutive_match_over_subsequence() -> None:
    """Adjacent matched characters outscore a scattered subsequence."""
    adjacent = fuzzy_match("main", "src/main.py")
    scattered = fuzzy_match("mn", "src/main.py")
    assert adjacent.score > scattered.score
def test_prefers_case_sensitive_match_over_case_insensitive() -> None:
    """An exact-case match outscores a case-folded one."""
    exact_case = fuzzy_match("Main", "src/Main.py")
    folded_case = fuzzy_match("main", "src/Main.py")
    assert exact_case.score > folded_case.score
def test_treats_uppercase_letter_as_word_boundary() -> None:
    """CamelCase humps act as word boundaries for matching."""
    match = fuzzy_match("MP", "src/MainPy.py")
    assert match.matched_indices == (4, 8)
def test_favors_earlier_positions() -> None:
    """Among equally good candidates, the earliest occurrence is chosen."""
    match = fuzzy_match("a", "banana")
    assert match.matched_indices == (1,)

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
from pathlib import Path
import pytest
from vibe.core.autocompletion.completers import PathCompleter
@pytest.fixture()
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Create a small project tree and make it the working directory."""
    directories = ["src/utils", "src/core", "config"]
    files = [
        "src/main.py",
        "src/models.py",
        "src/core/logger.py",
        "src/core/models.py",
        "src/core/ports.py",
        "src/core/sanitize.py",
        "src/core/use_cases.py",
        "src/core/validate.py",
        "README.md",
        ".env",
        "config/settings.py",
        "config/database.py",
    ]
    for directory in directories:
        (tmp_path / directory).mkdir(parents=True)
    for relative in files:
        (tmp_path / relative).write_text("", encoding="utf-8")
    monkeypatch.chdir(tmp_path)
    return tmp_path
def test_fuzzy_matches_subsequence_characters(file_tree: Path) -> None:
    """A prefix fragment of a directory name surfaces that directory."""
    completions = PathCompleter().get_completions("@sr", cursor_pos=3)
    assert "@src/" in completions
def test_fuzzy_matches_consecutive_characters_higher(file_tree: Path) -> None:
    """A consecutive run inside a path surfaces the matching file."""
    completions = PathCompleter().get_completions("@src/main", cursor_pos=9)
    assert "@src/main.py" in completions
def test_fuzzy_matches_prefix_highest(file_tree: Path) -> None:
    """Entries sharing the typed prefix are ranked first."""
    completions = PathCompleter().get_completions("@src", cursor_pos=4)
    assert completions[0].startswith("@src")
def test_fuzzy_matches_across_directory_boundaries(file_tree: Path) -> None:
    """Patterns may span the directory separator."""
    completions = PathCompleter().get_completions("@src/main", cursor_pos=9)
    assert "@src/main.py" in completions
def test_fuzzy_matches_case_insensitive(file_tree: Path) -> None:
    """Both lowercase and exact-case queries find the same file."""
    completer = PathCompleter()
    for query in ("@readme", "@README"):
        assert "@README.md" in completer.get_completions(query, cursor_pos=7)
def test_fuzzy_matches_word_boundaries_preferred(file_tree: Path) -> None:
    """Initials of path words ('m', 'p') match at boundary characters."""
    completions = PathCompleter().get_completions("@src/mp", cursor_pos=7)
    assert "@src/models.py" in completions
def test_fuzzy_matches_empty_pattern_shows_all(file_tree: Path) -> None:
    """A bare '@' lists every non-hidden entry."""
    completions = PathCompleter().get_completions("@", cursor_pos=1)
    for expected in ("@README.md", "@src/"):
        assert expected in completions
def test_fuzzy_matches_hidden_files_only_with_dot(file_tree: Path) -> None:
    """Dotfiles stay hidden unless the query itself starts with a dot."""
    hidden = "@.env"
    completer = PathCompleter()
    assert hidden not in completer.get_completions("@e", cursor_pos=2)
    assert hidden in completer.get_completions("@.", cursor_pos=2)
def test_fuzzy_matches_directories_and_files(file_tree: Path) -> None:
    """Listing a directory yields both subdirectories and plain files."""
    completions = PathCompleter().get_completions("@src/", cursor_pos=5)
    directories = [entry for entry in completions if entry.endswith("/")]
    files = [entry for entry in completions if not entry.endswith("/")]
    assert directories
    assert files
def test_fuzzy_matches_sorted_by_score(file_tree: Path) -> None:
    """The highest-scoring candidate comes first."""
    completions = PathCompleter().get_completions("@src/main", cursor_pos=9)
    assert completions[0] == "@src/main.py"
def test_fuzzy_matches_nested_directories(file_tree: Path) -> None:
    """Completion descends into nested directories."""
    completions = PathCompleter().get_completions("@src/core/l", cursor_pos=11)
    assert "@src/core/logger.py" in completions
def test_fuzzy_matches_partial_filename(file_tree: Path) -> None:
    """A filename prefix is enough to surface the file."""
    completions = PathCompleter().get_completions("@src/mo", cursor_pos=7)
    assert "@src/models.py" in completions
def test_fuzzy_matches_multiple_files_with_same_pattern(file_tree: Path) -> None:
    """All candidates sharing the pattern are offered."""
    completions = PathCompleter().get_completions("@src/m", cursor_pos=6)
    for expected in ("@src/main.py", "@src/models.py"):
        assert expected in completions
def test_fuzzy_matches_no_results_when_no_match(file_tree: Path) -> None:
    """A pattern matching nothing yields an empty list."""
    assert PathCompleter().get_completions("@xyz123", cursor_pos=7) == []
def test_fuzzy_matches_directory_traversal(file_tree: Path) -> None:
    """Listing a directory includes its files and child directories."""
    completions = PathCompleter().get_completions("@src/", cursor_pos=5)
    for expected in ("@src/main.py", "@src/core/", "@src/utils/"):
        assert expected in completions

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from pathlib import Path
import pytest
from vibe.core.autocompletion.completers import PathCompleter
@pytest.fixture()
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Create a nested repo-like tree and make it the working directory."""
    directories = [
        "vibe/acp",
        "vibe/cli/autocompletion",
        "tests/autocompletion",
    ]
    files = [
        "vibe/acp/entrypoint.py",
        "vibe/acp/agent.py",
        "vibe/cli/autocompletion/fuzzy.py",
        "vibe/cli/autocompletion/completers.py",
        "tests/autocompletion/test_fuzzy.py",
        "README.md",
    ]
    for directory in directories:
        (tmp_path / directory).mkdir(parents=True)
    for relative in files:
        (tmp_path / relative).write_text("")
    monkeypatch.chdir(tmp_path)
    return tmp_path
def test_finds_files_recursively_by_filename(file_tree: Path) -> None:
    """A bare filename fragment surfaces a deeply nested file first."""
    completions = PathCompleter().get_completions("@entryp", cursor_pos=7)
    assert completions[0] == "@vibe/acp/entrypoint.py"
def test_finds_files_recursively_by_partial_path(file_tree: Path) -> None:
    """A partial dir/file path matches without the leading directories."""
    completions = PathCompleter().get_completions("@acp/entry", cursor_pos=10)
    assert completions[0] == "@vibe/acp/entrypoint.py"
def test_finds_files_recursively_with_subsequence(file_tree: Path) -> None:
    """A scattered subsequence across the path still ranks the file first."""
    completions = PathCompleter().get_completions("@acp/ent", cursor_pos=9)
    assert completions[0] == "@vibe/acp/entrypoint.py"
def test_finds_multiple_matches_recursively(file_tree: Path) -> None:
    """Source-tree matches rank ahead of test-tree matches."""
    completions = PathCompleter().get_completions("@fuzzy", cursor_pos=6)
    source_rank = completions.index("@vibe/cli/autocompletion/fuzzy.py")
    test_rank = completions.index("@tests/autocompletion/test_fuzzy.py")
    assert source_rank < test_rank
def test_prioritizes_exact_path_matches(file_tree: Path) -> None:
    """Typing a full path ranks the exact file first."""
    completions = PathCompleter().get_completions("@vibe/acp/entrypoint", cursor_pos=20)
    assert completions[0] == "@vibe/acp/entrypoint.py"
def test_finds_files_when_pattern_matches_directory_name(file_tree: Path) -> None:
    """Matching a directory lists it, its children, then weaker matches."""
    completions = PathCompleter().get_completions("@acp", cursor_pos=4)
    expected = [
        "@vibe/acp/",
        "@vibe/acp/agent.py",
        "@vibe/acp/entrypoint.py",
        "@vibe/cli/autocompletion/completers.py",
        "@tests/autocompletion/",
        "@tests/autocompletion/test_fuzzy.py",
        "@vibe/cli/autocompletion/",
        "@vibe/cli/autocompletion/fuzzy.py",
    ]
    assert completions == expected

View File

@@ -0,0 +1,258 @@
from __future__ import annotations
from pathlib import Path
import pytest
from textual import events
from vibe.cli.autocompletion.base import CompletionResult, CompletionView
from vibe.cli.autocompletion.path_completion import PathCompletionController
from vibe.core.autocompletion.completers import PathCompleter
class StubView(CompletionView):
    """In-memory CompletionView that records every call for later assertions."""
    def __init__(self) -> None:
        # (suggestions, selected_index) pairs, in call order.
        self.suggestions: list[tuple[list[tuple[str, str]], int]] = []
        # Number of times the suggestion popup was cleared.
        self.clears = 0
        # (start, end, replacement) ranges applied to the input text.
        self.replacements: list[tuple[int, int, str]] = []
    def render_completion_suggestions(
        self, suggestions: list[tuple[str, str]], selected_index: int
    ) -> None:
        # Record the call instead of rendering anything.
        self.suggestions.append((suggestions, selected_index))
    def clear_completion_suggestions(self) -> None:
        self.clears += 1
    def replace_completion_range(self, start: int, end: int, replacement: str) -> None:
        self.replacements.append((start, end, replacement))
@pytest.fixture()
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Create a sample project tree and make it the working directory."""
    directories = ["src/utils", "src/core"]
    files = [
        "src/main.py",
        "src/core/logger.py",
        "src/core/models.py",
        "src/core/ports.py",
        "src/core/sanitize.py",
        "src/core/use_cases.py",
        "src/core/validate.py",
        "README.md",
        ".env",
    ]
    for directory in directories:
        (tmp_path / directory).mkdir(parents=True)
    for relative in files:
        (tmp_path / relative).write_text("", encoding="utf-8")
    monkeypatch.chdir(tmp_path)
    return tmp_path
def make_controller(
    max_entries_to_process: int | None = None, target_matches: int | None = None
) -> tuple[PathCompletionController, StubView]:
    """Wire a PathCompletionController to a StubView, forwarding set limits."""
    limits = {
        "max_entries_to_process": max_entries_to_process,
        "target_matches": target_matches,
    }
    # Only pass limits that were explicitly provided.
    completer_kwargs = {key: value for key, value in limits.items() if value is not None}
    view = StubView()
    controller = PathCompletionController(PathCompleter(**completer_kwargs), view)
    return controller, view
def test_lists_root_entries(file_tree: Path) -> None:
    """A bare '@' lists the root entries with the first one selected."""
    controller, view = make_controller()
    controller.on_text_changed("@", cursor_index=1)
    shown, highlighted = view.suggestions[-1]
    aliases = [alias for alias, _ in shown]
    assert highlighted == 0
    assert aliases == ["@README.md", "@src/"]
def test_suggests_hidden_entries_only_with_dot_prefix(file_tree: Path) -> None:
    """Typing '@.' surfaces dotfiles, with '.env' ranked first."""
    controller, view = make_controller()
    controller.on_text_changed("@.", cursor_index=2)
    shown, _ = view.suggestions[-1]
    first_alias, _ = shown[0]
    assert first_alias == "@.env"
def test_lists_nested_entries_when_prefixing_with_folder_name(file_tree: Path) -> None:
    """Ending the fragment with '/' lists that folder's children, sorted."""
    controller, view = make_controller()
    controller.on_text_changed("@src/", cursor_index=5)
    shown, _ = view.suggestions[-1]
    aliases = [alias for alias, _ in shown]
    assert aliases == ["@src/core/", "@src/main.py", "@src/utils/"]
def test_resets_when_fragment_invalid(file_tree: Path) -> None:
    """A fragment broken by whitespace clears suggestions and ignores keys."""
    controller, view = make_controller()
    controller.on_text_changed("@src", cursor_index=4)
    assert view.suggestions
    controller.on_text_changed("@src foo", cursor_index=8)
    assert view.clears == 1
    outcome = controller.on_key(events.Key("tab", None), "@src foo", 8)
    assert outcome is CompletionResult.IGNORED
def test_applies_selected_completion_on_tab_keycode(file_tree: Path) -> None:
    """Tab replaces the fragment with the selection and clears the popup."""
    controller, view = make_controller()
    controller.on_text_changed("@R", cursor_index=2)
    outcome = controller.on_key(events.Key("tab", None), "@R", 2)
    assert outcome is CompletionResult.HANDLED
    assert view.replacements == [(0, 2, "@README.md")]
    assert view.clears == 1
def test_applies_selected_completion_on_enter_keycode(file_tree: Path) -> None:
    """Enter applies the entry highlighted after navigating down once."""
    controller, view = make_controller()
    controller.on_text_changed("@src/", cursor_index=5)
    controller.on_key(events.Key("down", None), "@src/", 5)
    outcome = controller.on_key(events.Key("enter", None), "@src/", 5)
    assert outcome is CompletionResult.HANDLED
    assert view.replacements == [(0, 5, "@src/main.py")]
    assert view.clears == 1
def test_navigates_and_cycles_across_suggestions(file_tree: Path) -> None:
    """Down/up move the highlight and wrap around at both ends of the list."""
    controller, view = make_controller()
    controller.on_text_changed("@src/", cursor_index=5)
    controller.on_key(events.Key("down", None), "@src/", 5)
    suggestions, selected_index = view.suggestions[-1]
    assert [alias for alias, _ in suggestions] == [
        "@src/core/",
        "@src/main.py",
        "@src/utils/",
    ]
    assert selected_index == 1
    # Moving back up returns to the first entry.
    controller.on_key(events.Key("up", None), "@src/", 5)
    suggestions, selected_index = view.suggestions[-1]
    assert selected_index == 0
    # Two downs land on the last entry.
    controller.on_key(events.Key("down", None), "@src/", 5)
    controller.on_key(events.Key("down", None), "@src/", 5)
    suggestions, selected_index = view.suggestions[-1]
    assert selected_index == 2
    # A further down wraps to the top.
    controller.on_key(events.Key("down", None), "@src/", 5)
    suggestions, selected_index = view.suggestions[-1]
    assert selected_index == 0
def test_limits_suggestions_to_ten(file_tree: Path) -> None:
    """The popup never shows more than ten entries even when more match."""
    extra_dir = file_tree / "src" / "core" / "extra"
    extra_dir.mkdir(parents=True)
    # Plain loop instead of the original throwaway list comprehension that
    # was built purely for its side effects.
    for i in range(1, 13):
        (extra_dir / f"extra_file_{i}.py").write_text("", encoding="utf-8")
    controller, view = make_controller()
    controller.on_text_changed("@src/core/extra/", cursor_index=16)
    suggestions, selected_index = view.suggestions[-1]
    assert len(suggestions) == 10
    # Lexicographic ordering: "extra_file_10" sorts before "extra_file_2".
    assert [alias for alias, _ in suggestions] == [
        "@src/core/extra/extra_file_1.py",
        "@src/core/extra/extra_file_10.py",
        "@src/core/extra/extra_file_11.py",
        "@src/core/extra/extra_file_12.py",
        "@src/core/extra/extra_file_2.py",
        "@src/core/extra/extra_file_3.py",
        "@src/core/extra/extra_file_4.py",
        "@src/core/extra/extra_file_5.py",
        "@src/core/extra/extra_file_6.py",
        "@src/core/extra/extra_file_7.py",
    ]
    assert selected_index == 0
def test_does_not_handle_when_cursor_at_beginning_of_input(file_tree: Path) -> None:
    """With the cursor at position 0 there is no fragment to complete."""
    controller, _ = make_controller()
    for text in ("@file", "", "hello@file"):
        assert not controller.can_handle(text, cursor_index=0)
def test_does_not_handle_when_cursor_before_or_at_the_at_symbol(
    file_tree: Path,
) -> None:
    """The controller only activates once the cursor is past the '@'."""
    controller, _ = make_controller()
    assert not controller.can_handle("@file", cursor_index=0)
    assert not controller.can_handle("hello@file", cursor_index=5)
def test_does_handle_when_cursor_after_the_at_symbol_even_in_the_middle_of_the_input(
    file_tree: Path,
) -> None:
    """A cursor anywhere past an '@' activates completion, even mid-input."""
    controller, _ = make_controller()
    assert controller.can_handle("@file", cursor_index=1)
    assert controller.can_handle("hello @file", cursor_index=7)
def test_lists_immediate_children_when_path_ends_with_slash(file_tree: Path) -> None:
    """A trailing '/' lists only the folder's immediate children."""
    controller, view = make_controller()
    controller.on_text_changed("@src/", cursor_index=5)
    shown, _ = view.suggestions[-1]
    aliases = [alias for alias, _ in shown]
    assert aliases == ["@src/core/", "@src/main.py", "@src/utils/"]
def test_respects_max_entries_to_process_limit(file_tree: Path) -> None:
    """The completer stops scanning once max_entries_to_process is reached."""
    for i in range(30):
        (file_tree / f"file_{i:03d}.txt").write_text("", encoding="utf-8")
    controller, view = make_controller(max_entries_to_process=10)
    controller.on_text_changed("@", cursor_index=1)
    shown, _ = view.suggestions[-1]
    assert len(shown) <= 10
def test_respects_target_matches_limit_for_listing(file_tree: Path) -> None:
    """Directory listings honor the target_matches cap."""
    for i in range(30):
        (file_tree / f"item_{i:03d}.txt").write_text("", encoding="utf-8")
    controller, view = make_controller(target_matches=5)
    controller.on_text_changed("@", cursor_index=1)
    shown, _ = view.suggestions[-1]
    assert len(shown) <= 5
def test_respects_target_matches_limit_for_fuzzy_search(file_tree: Path) -> None:
    """Fuzzy searches honor the target_matches cap."""
    for i in range(30):
        (file_tree / f"test_file_{i:03d}.py").write_text("", encoding="utf-8")
    controller, view = make_controller(target_matches=5)
    controller.on_text_changed("@test", cursor_index=5)
    shown, _ = view.suggestions[-1]
    assert len(shown) <= 5

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from pathlib import Path
from vibe.core.autocompletion.path_prompt_adapter import (
DEFAULT_MAX_EMBED_BYTES,
render_path_prompt,
)
def test_treats_paths_to_files_as_embedded_resources(tmp_path: Path) -> None:
    """File mentions are rewritten and their contents embedded in fences."""
    readme_path = tmp_path / "README.md"
    readme_path.write_text("hello", encoding="utf-8")
    (tmp_path / "src").mkdir()
    main_path = tmp_path / "src" / "main.py"
    main_path.write_text("print('hi')", encoding="utf-8")
    rendered = render_path_prompt(
        "Please review @README.md and @src/main.py",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    assert rendered == (
        f"Please review README.md and src/main.py\n\n"
        f"{readme_path.as_uri()}\n```\nhello\n```\n\n"
        f"{main_path.as_uri()}\n```\nprint('hi')\n```"
    )
def test_treats_path_to_directory_as_resource_links(tmp_path: Path) -> None:
    """Directory mentions become uri/name resource links, not embeds."""
    docs_path = tmp_path / "docs"
    docs_path.mkdir()
    rendered = render_path_prompt(
        "See @docs/ for details",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    assert rendered == (
        f"See docs/ for details\n\nuri: {docs_path.as_uri()}\nname: docs/"
    )
def test_keeps_emails_and_embeds_paths(tmp_path: Path) -> None:
    """Email addresses are left alone while real paths are embedded."""
    readme_path = tmp_path / "README.md"
    readme_path.write_text("hello", encoding="utf-8")
    rendered = render_path_prompt(
        "Contact user@example.com about @README.md",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    assert rendered == (
        f"Contact user@example.com about README.md\n\n"
        f"{readme_path.as_uri()}\n```\nhello\n```"
    )
def test_ignores_nonexistent_paths(tmp_path: Path) -> None:
    """Mentions of missing files are left untouched."""
    rendered = render_path_prompt(
        "Missing @nope.txt here",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    assert rendered == "Missing @nope.txt here"
def test_falls_back_to_link_for_binary_files(tmp_path: Path) -> None:
    """Binary content cannot be embedded, so a resource link is produced."""
    binary_file = tmp_path / "image.bin"
    binary_file.write_bytes(b"\x00\x01\x02")
    rendered = render_path_prompt(
        "Inspect @image.bin", base_dir=tmp_path, max_embed_bytes=DEFAULT_MAX_EMBED_BYTES
    )
    expected = f"Inspect image.bin\n\nuri: {binary_file.as_uri()}\nname: image.bin"
    assert rendered == expected
def test_excludes_supposed_binary_files_quickly_before_reading_content(
    tmp_path: Path,
) -> None:
    """Binary-looking extensions are linked without inspecting the bytes."""
    archive_file = tmp_path / "archive.zip"
    archive_file.write_text(
        "text content inside but treated as binary", encoding="utf-8"
    )
    rendered = render_path_prompt(
        "Inspect @archive.zip",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    expected = f"Inspect archive.zip\n\nuri: {archive_file.as_uri()}\nname: archive.zip"
    assert rendered == expected
def test_applies_max_embed_size_guard(tmp_path: Path) -> None:
    """Files larger than max_embed_bytes fall back to a resource link."""
    oversized = tmp_path / "big.txt"
    oversized.write_text("a" * 50, encoding="utf-8")
    rendered = render_path_prompt(
        "Review @big.txt", base_dir=tmp_path, max_embed_bytes=10
    )
    assert rendered == f"Review big.txt\n\nuri: {oversized.as_uri()}\nname: big.txt"
def test_parses_paths_with_special_characters_when_quoted(tmp_path: Path) -> None:
    """Quoting lets mentions include spaces and parentheses."""
    quirky = tmp_path / "weird name(1).txt"
    quirky.write_text("odd", encoding="utf-8")
    rendered = render_path_prompt(
        'Open @"weird name(1).txt"',
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    assert rendered == f"Open weird name(1).txt\n\n{quirky.as_uri()}\n```\nodd\n```"
def test_deduplicates_identical_paths(tmp_path: Path) -> None:
    """A file mentioned twice is embedded only once."""
    readme_path = tmp_path / "README.md"
    readme_path.write_text("hello", encoding="utf-8")
    rendered = render_path_prompt(
        "See @README.md and again @README.md",
        base_dir=tmp_path,
        max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
    )
    expected = (
        f"See README.md and again README.md\n\n{readme_path.as_uri()}\n```\nhello\n```"
    )
    assert rendered == expected

View File

@@ -0,0 +1,162 @@
from __future__ import annotations
from typing import NamedTuple
from textual import events
from vibe.cli.autocompletion.base import CompletionResult, CompletionView
from vibe.cli.autocompletion.slash_command import SlashCommandController
from vibe.core.autocompletion.completers import CommandCompleter
class Suggestion(NamedTuple):
    """One completion entry: the alias shown (e.g. "/config") and its help."""
    alias: str
    description: str
class SuggestionEvent(NamedTuple):
    """One render_completion_suggestions call: list shown and the highlight."""
    suggestions: list[Suggestion]
    selected_index: int
class Replacement(NamedTuple):
    """One replace_completion_range call: [start, end) replaced by text."""
    start: int
    end: int
    replacement: str
class StubView(CompletionView):
    """In-memory CompletionView that records calls as typed NamedTuples."""
    def __init__(self) -> None:
        # Every render call, in order.
        self.suggestion_events: list[SuggestionEvent] = []
        # Number of times the popup was cleared.
        self.reset_count = 0
        # Every range replacement applied to the input.
        self.replacements: list[Replacement] = []
    def render_completion_suggestions(
        self, suggestions: list[tuple[str, str]], selected_index: int
    ) -> None:
        # Convert raw tuples into the typed Suggestion for nicer assertions.
        typed = [Suggestion(alias, description) for alias, description in suggestions]
        self.suggestion_events.append(SuggestionEvent(typed, selected_index))
    def clear_completion_suggestions(self) -> None:
        self.reset_count += 1
    def replace_completion_range(self, start: int, end: int, replacement: str) -> None:
        self.replacements.append(Replacement(start, end, replacement))
def key_event(key: str) -> events.Key:
    """Build a textual Key event with no printable character attached."""
    return events.Key(key, character=None)
def make_controller(
    *, prefix: str | None = None
) -> tuple[SlashCommandController, StubView]:
    """Build a SlashCommandController over a fixed command set.

    If `prefix` is given, it is typed first and the recorded events are
    cleared, so tests only see events caused by their own input.
    """
    # Note: "/config" appears twice on purpose — tests assert that duplicates
    # are collapsed and the later description wins. Insertion order matters.
    commands = [
        ("/config", "Show current configuration"),
        ("/compact", "Compact history"),
        ("/help", "Display help"),
        ("/config", "Override description"),
        ("/summarize", "Summarize history"),
        ("/logpath", "Show log path"),
        ("/exit", "Exit application"),
        ("/vim", "Toggle vim keybindings"),
    ]
    completer = CommandCompleter(commands)
    view = StubView()
    controller = SlashCommandController(completer, view)
    if prefix is not None:
        controller.on_text_changed(prefix, cursor_index=len(prefix))
        view.suggestion_events.clear()
    return controller, view
def test_on_text_change_emits_matching_suggestions_in_insertion_order_and_ignores_duplicates() -> (
    None
):
    """Duplicated aliases keep one entry (last description wins), in order."""
    controller, view = make_controller(prefix="/c")
    controller.on_text_changed("/c", cursor_index=2)
    event = view.suggestion_events[-1]
    assert event.selected_index == 0
    assert event.suggestions == [
        Suggestion("/config", "Override description"),
        Suggestion("/compact", "Compact history"),
    ]
def test_on_text_change_filters_suggestions_case_insensitively() -> None:
    """Uppercase input matches the lowercase command aliases."""
    controller, view = make_controller(prefix="/c")
    controller.on_text_changed("/CO", cursor_index=3)
    event = view.suggestion_events[-1]
    assert [item.alias for item in event.suggestions] == ["/config", "/compact"]
def test_on_text_change_clears_suggestions_when_no_matches() -> None:
    """Text that no longer starts with '/' clears the popup."""
    controller, view = make_controller(prefix="/c")
    controller.on_text_changed("/c", cursor_index=2)
    controller.on_text_changed("config", cursor_index=6)
    assert view.reset_count >= 1
def test_on_text_change_limits_the_number_of_results_to_five_and_preserve_insertion_order() -> (
    None
):
    """At most five suggestions are shown, in registration order."""
    controller, view = make_controller(prefix="/")
    controller.on_text_changed("/", cursor_index=1)
    event = view.suggestion_events[-1]
    assert len(event.suggestions) == 5
    assert [item.alias for item in event.suggestions] == [
        "/config",
        "/compact",
        "/help",
        "/summarize",
        "/logpath",
    ]
def test_on_key_tab_applies_selected_completion() -> None:
    """Tab inserts the highlighted alias and clears the popup."""
    controller, view = make_controller(prefix="/c")
    outcome = controller.on_key(key_event("tab"), text="/c", cursor_index=2)
    assert outcome is CompletionResult.HANDLED
    assert view.replacements == [Replacement(0, 2, "/config")]
    assert view.reset_count == 1
def test_on_key_down_and_up_cycle_selection() -> None:
    """Down/up cycle the highlight across the two matching commands."""
    controller, view = make_controller(prefix="/c")
    controller.on_key(key_event("down"), text="/c", cursor_index=2)
    suggestions, selected_index = view.suggestion_events[-1]
    assert selected_index == 1
    # A second down wraps back to the first entry.
    controller.on_key(key_event("down"), text="/c", cursor_index=2)
    suggestions, selected_index = view.suggestion_events[-1]
    assert selected_index == 0
    # Up wraps backwards to the last entry.
    controller.on_key(key_event("up"), text="/c", cursor_index=2)
    suggestions, selected_index = view.suggestion_events[-1]
    assert selected_index == 1
    assert [suggestion.alias for suggestion in suggestions] == ["/config", "/compact"]
def test_on_key_enter_submits_selected_completion() -> None:
    """Enter submits the highlighted alias and clears the popup."""
    controller, view = make_controller(prefix="/c")
    controller.on_key(key_event("down"), text="/c", cursor_index=2)
    outcome = controller.on_key(key_event("enter"), text="/c", cursor_index=2)
    assert outcome is CompletionResult.SUBMIT
    assert view.replacements == [Replacement(0, 2, "/compact")]
    assert view.reset_count == 1

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
from pathlib import Path
import pytest
from textual.content import Content
from textual.style import Style
from textual.widgets import Markdown
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.textual_ui.widgets.chat_input.completion_popup import CompletionPopup
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
from vibe.core.config import SessionLoggingConfig, VibeConfig
@pytest.fixture
def vibe_config() -> VibeConfig:
    """App config with session logging disabled so tests write no log files."""
    return VibeConfig(session_logging=SessionLoggingConfig(enabled=False))
@pytest.fixture
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
    """A VibeApp instance built from the test configuration."""
    return VibeApp(config=vibe_config)
@pytest.mark.asyncio
async def test_popup_appears_with_matching_suggestions(vibe_app: VibeApp) -> None:
    """Typing '/sum' shows the popup with the matching command and its help."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/sum")
        popup_content = str(popup.render())
        assert popup.styles.display == "block"
        assert "/summarize" in popup_content
        assert "Compact conversation history by summarizing" in popup_content
        assert chat_input.value == "/sum"
@pytest.mark.asyncio
async def test_popup_hides_when_input_cleared(vibe_app: VibeApp) -> None:
    """Deleting the typed fragment hides the completion popup."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/c")
        await pilot.press("backspace", "backspace")
        assert popup.styles.display == "none"
@pytest.mark.asyncio
async def test_pressing_tab_writes_selected_command_and_keeps_popup_visible(
    vibe_app: VibeApp,
) -> None:
    """Tab completes the input to '/config' without dismissing the popup."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/co")
        await pilot.press("tab")
        assert chat_input.value == "/config"
        assert popup.styles.display == "block"
def ensure_selected_command(popup: CompletionPopup, expected_alias: str) -> None:
    """Assert that exactly one popup row is highlighted and shows the alias.

    The selected row is detected by its reversed Style span in the rendered
    Content; the alias is the first whitespace-delimited token of that span.
    """
    renderable = popup.render()
    assert isinstance(renderable, Content)
    content = str(renderable)
    selected_aliases: list[str] = []
    for span in renderable.spans:
        style = span.style
        # The selected row is rendered with a reversed style.
        if isinstance(style, Style) and style.reverse:
            alias_text = content[span.start : span.end].strip()
            alias = alias_text.split()[0] if alias_text else ""
            selected_aliases.append(alias)
    assert len(selected_aliases) == 1
    assert selected_aliases[0] == expected_alias
@pytest.mark.asyncio
async def test_arrow_navigation_updates_selected_suggestion(vibe_app: VibeApp) -> None:
    """Down/up arrows move the highlight between '/cfg' and '/config'."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/c")
        ensure_selected_command(popup, "/cfg")
        await pilot.press("down")
        ensure_selected_command(popup, "/config")
        await pilot.press("up")
        ensure_selected_command(popup, "/cfg")
@pytest.mark.asyncio
async def test_arrow_navigation_cycles_through_suggestions(vibe_app: VibeApp) -> None:
    """Navigation wraps around when moving past either end of the list."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/st")
        ensure_selected_command(popup, "/stats")
        await pilot.press("down")
        ensure_selected_command(popup, "/status")
        await pilot.press("up")
        ensure_selected_command(popup, "/stats")
@pytest.mark.asyncio
async def test_pressing_enter_submits_selected_command_and_hides_popup(
    vibe_app: VibeApp,
) -> None:
    """Enter runs the highlighted command, clears input, hides the popup."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"/hel") # typos:disable-line
        await pilot.press("enter")
        assert chat_input.value == ""
        assert popup.styles.display == "none"
        # The executed command's output appears as a user-command message.
        message = vibe_app.query_one(".user-command-message")
        message_content = message.query_one(Markdown)
        assert "Show help message" in message_content.source
@pytest.fixture()
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
    """Create a sample project tree and make it the working directory."""
    directories = ["src/utils", "vibe/acp"]
    files = [
        "src/utils/config.py",
        "src/utils/database.py",
        "src/utils/error_handling.py",
        "src/utils/logger.py",
        "src/utils/sanitize.py",
        "src/utils/validate.py",
        "src/main.py",
        "vibe/acp/entrypoint.py",
        "vibe/acp/agent.py",
        "README.md",
        ".env",
    ]
    for directory in directories:
        (tmp_path / directory).mkdir(parents=True)
    for relative in files:
        (tmp_path / relative).write_text("", encoding="utf-8")
    monkeypatch.chdir(tmp_path)
    return tmp_path
@pytest.mark.asyncio
async def test_path_completion_popup_lists_files_and_directories(
    vibe_app: VibeApp, file_tree: Path
) -> None:
    """Typing "@s" opens the popup and suggests the matching "src/" directory."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@s")
        popup_content = str(popup.render())
        # Directories are suggested with a trailing slash.
        assert "@src/" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_path_completion_popup_shows_up_to_ten_results(
    vibe_app: VibeApp, file_tree: Path
) -> None:
    """The popup caps suggestions at ten entries for a directory with 12 files."""
    async with vibe_app.run_test() as pilot:
        # Create 12 candidate files so the ten-result cap is exercised.
        # (A plain loop replaces the original side-effect-only list comprehension.)
        extra_dir = file_tree / "src" / "core" / "extra"
        extra_dir.mkdir(parents=True)
        for i in range(1, 13):
            (extra_dir / f"extra_file_{i}.py").write_text("", encoding="utf-8")
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@src/core/extra/")
        popup_content = str(popup.render())
        # Lexicographic order puts 1, 10, 11, 12, 2..7 in the first ten hits;
        # extra_file_8.py and extra_file_9.py fall outside the cap.
        for shown in ("1", "10", "11", "12", "2", "3", "4", "5", "6", "7"):
            assert f"@src/core/extra/extra_file_{shown}.py" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_pressing_tab_writes_selected_path_name_and_hides_popup(
    vibe_app: VibeApp, file_tree: Path
) -> None:
    """Tab completes "@REA" to "@README.md " and dismisses the popup."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"Print @REA")
        await pilot.press("tab")
        # The completed path is inserted with a trailing space.
        assert chat_input.value == "Print @README.md "
        assert popup.styles.display == "none"
@pytest.mark.asyncio
async def test_pressing_enter_writes_selected_path_name_and_hides_popup(
    vibe_app: VibeApp, file_tree: Path
) -> None:
    """Enter also accepts a path completion instead of submitting the message."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"Print @src/m")
        await pilot.press("enter")
        # The completed path is inserted with a trailing space.
        assert chat_input.value == "Print @src/main.py "
        assert popup.styles.display == "none"
@pytest.mark.asyncio
async def test_fuzzy_matches_subsequence_characters(
    file_tree: Path, vibe_app: VibeApp
) -> None:
    """Query "handling" fuzzily matches "error_handling.py" in that directory."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@src/utils/handling")
        popup_content = str(popup.render())
        assert "@src/utils/error_handling.py" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_fuzzy_matches_word_boundaries(
    file_tree: Path, vibe_app: VibeApp
) -> None:
    """Query "eh" matches "error_handling.py" via its word initials."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@src/utils/eh")
        popup_content = str(popup.render())
        assert "@src/utils/error_handling.py" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_finds_files_recursively_by_filename(
    file_tree: Path, vibe_app: VibeApp
) -> None:
    """A bare filename query finds a deeply nested file without its path."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@entryp")
        popup_content = str(popup.render())
        assert "@vibe/acp/entrypoint.py" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_finds_files_recursively_with_partial_path(
    file_tree: Path, vibe_app: VibeApp
) -> None:
    """A partial path ("acp/entry") still resolves to the full nested path."""
    async with vibe_app.run_test() as pilot:
        popup = vibe_app.query_one(CompletionPopup)
        await pilot.press(*"@acp/entry")
        popup_content = str(popup.render())
        assert "@vibe/acp/entrypoint.py" in popup_content
        assert popup.styles.display == "block"
@pytest.mark.asyncio
async def test_does_not_trigger_completion_when_navigating_history(
    file_tree: Path, vibe_app: VibeApp
) -> None:
    """Recalling history entries containing "@" must not reopen the popup."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        popup = vibe_app.query_one(CompletionPopup)
        message_with_path = "Check @src/m"
        message_to_fill_history = "Yet another message to fill history"
        # Submit one message that used path completion, then a second message.
        await pilot.press(*message_with_path)
        await pilot.press("tab", "enter")
        await pilot.press(*message_to_fill_history)
        await pilot.press("enter")
        # Recall the older message (two steps back) from history.
        await pilot.press("up", "up")
        assert chat_input.value == "Check @src/main.py"
        await pilot.pause(0.2)
        # ensure popup is hidden - user was navigating history: we don't want to interrupt
        assert popup.styles.display == "none"
        await pilot.press("down")
        await pilot.pause(0.1)
        assert popup.styles.display == "none"
        # get back to the message with path completion; ensure again
        await pilot.press("up")
        await pilot.pause(0.1)
        assert chat_input.value == "Check @src/main.py"
        await pilot.pause(0.2)
        assert popup.styles.display == "none"

View File

View File

@@ -0,0 +1,6 @@
from __future__ import annotations
# Type aliases shared by the backend test-data modules.
Url = str  # provider base URL, e.g. "https://api.mistral.ai"
JsonResponse = dict  # raw JSON body returned by a mocked /chat/completions call
ResultData = dict  # expected fields extracted from the parsed backend result
Chunk = bytes  # one raw SSE line (b"data: ...") from a streamed response

View File

@@ -0,0 +1,183 @@
from __future__ import annotations
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
# (base_url, mocked response JSON, expected parsed result) for a plain
# assistant reply (finish_reason "stop") from the Fireworks API.
SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
    (
        "https://api.fireworks.ai",
        {
            "id": "fake_id_1234",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "accounts/fireworks/models/glm-4p5",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Some content",
                        "reasoning_content": "Some reasoning content",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "total_tokens": 300,
                "completion_tokens": 200,
                "prompt_tokens_details": {"cached_tokens": 0},
            },
        },
        {
            "message": "Some content",
            "finish_reason": "stop",
            "usage": {
                "prompt_tokens": 100,
                "total_tokens": 300,
                "completion_tokens": 200,
            },
        },
    )
]
# (base_url, mocked response JSON, expected parsed result) for a Fireworks
# reply that requests a tool call (finish_reason "tool_calls", empty content).
TOOL_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
    (
        "https://api.fireworks.ai",
        {
            "id": "fake_id_1234",
            "object": "chat.completion",
            "created": 1234567890,
            "model": "accounts/fireworks/models/glm-4p5",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "reasoning_content": "Some reasoning content",
                        "tool_calls": [
                            {
                                "index": 0,
                                "id": "fake_id_5678",
                                "type": "function",
                                "function": {
                                    "name": "some_tool",
                                    "arguments": '{"some_argument": "some_argument_value"}',
                                },
                                "name": None,
                            }
                        ],
                    },
                    "finish_reason": "tool_calls",
                }
            ],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 200,
                "prompt_tokens_details": {"cached_tokens": 0},
            },
        },
        {
            "message": "",
            "finish_reason": "tool_calls",
            "tool_calls": [
                {
                    "name": "some_tool",
                    "arguments": '{"some_argument": "some_argument_value"}',
                    "index": 0,
                }
            ],
            "usage": {"prompt_tokens": 100, "completion_tokens": 200},
        },
    )
]
# (base_url, raw SSE chunk lines, expected per-chunk results) for a streamed
# plain-text Fireworks completion; usage is only populated on the final chunk.
STREAMED_SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
    (
        "https://api.fireworks.ai",
        [
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}],"usage":null}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"reasoning_content":"Some reasoning content"},"finish_reason":null}],"usage":null}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"usage":null}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200,"prompt_tokens_details":{"cached_tokens":0}}}',
            rb"data: [DONE]",
        ],
        [
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "Some content",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 100, "completion_tokens": 200},
            },
        ],
    )
]
# (base_url, raw SSE chunk lines, expected per-chunk results) for a streamed
# Fireworks completion ending in a tool call; the tool name and its arguments
# arrive in separate chunks.
STREAMED_TOOL_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
    (
        "https://api.fireworks.ai",
        [
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}],"usage": null}',
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"reasoning_content": "Some reasoning content"},"finish_reason": null}],"usage": null}',
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"content": "Some content"},"finish_reason": null}],"usage": null}',
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"tool_calls": [{"index": 0,"id": "fake_id_151617","type": "function","function": {"name": "some_tool"}}]},"finish_reason": null}],"usage": null}',
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"tool_calls": [{"index": 0,"id": null,"type": "function","function": {"arguments": "{\"some_argument\": \"some_arguments_value\"}"}}]},"finish_reason": null}],"usage": null}',
            rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}],"usage": {"prompt_tokens": 100,"total_tokens": 300,"completion_tokens": 200,"prompt_tokens_details": {"cached_tokens": 190}}}',
            rb"data: [DONE]",
        ],
        [
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "Some content",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            # Tool name arrives first, with arguments still unset.
            {
                "message": "",
                "finish_reason": None,
                "tool_calls": [{"name": "some_tool", "arguments": None, "index": 0}],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            # Arguments arrive in a later chunk without repeating the name.
            {
                "message": "",
                "finish_reason": None,
                "tool_calls": [
                    {
                        "name": None,
                        "arguments": '{"some_argument": "some_arguments_value"}',
                        "index": 0,
                    }
                ],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": "tool_calls",
                "usage": {"prompt_tokens": 100, "completion_tokens": 200},
            },
        ],
    )
]

View File

@@ -0,0 +1,173 @@
from __future__ import annotations
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
# (base_url, mocked response JSON, expected parsed result) for a plain
# assistant reply (finish_reason "stop") from the Mistral API.
SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
    (
        "https://api.mistral.ai",
        {
            "id": "fake_id_1234",
            "created": 1234567890,
            "model": "devstral-latest",
            "usage": {
                "prompt_tokens": 100,
                "total_tokens": 300,
                "completion_tokens": 200,
            },
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "stop",
                    "message": {
                        "role": "assistant",
                        "tool_calls": None,
                        "content": "Some content",
                    },
                }
            ],
        },
        {
            "message": "Some content",
            "finish_reason": "stop",
            "usage": {
                "prompt_tokens": 100,
                "total_tokens": 300,
                "completion_tokens": 200,
            },
        },
    )
]
# (base_url, mocked response JSON, expected parsed result) for a Mistral reply
# that includes both content and a tool call (finish_reason "tool_calls").
TOOL_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
    (
        "https://api.mistral.ai",
        {
            "id": "fake_id_1234",
            "created": 1234567890,
            "model": "devstral-latest",
            "usage": {
                "prompt_tokens": 100,
                "total_tokens": 300,
                "completion_tokens": 200,
            },
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "finish_reason": "tool_calls",
                    "message": {
                        "role": "assistant",
                        "tool_calls": [
                            {
                                "id": "fake_id_5678",
                                "function": {
                                    "name": "some_tool",
                                    "arguments": '{"some_argument": "some_argument_value"}',
                                },
                                "index": 0,
                            }
                        ],
                        "content": "Some content",
                    },
                }
            ],
        },
        {
            "message": "Some content",
            "finish_reason": "tool_calls",
            "tool_calls": [
                {
                    "name": "some_tool",
                    "arguments": '{"some_argument": "some_argument_value"}',
                    "index": 0,
                }
            ],
            "usage": {"prompt_tokens": 100, "completion_tokens": 200},
        },
    )
]
# (base_url, raw SSE chunk lines, expected per-chunk results) for a streamed
# plain-text Mistral completion; chunks carry an extra "p" padding field and
# usage is only populated on the final chunk.
STREAMED_SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
    (
        "https://api.mistral.ai",
        [
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"p":"abcde"}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200},"p":"abcdefghijklmnopq"}',
            rb"data: [DONE]",
        ],
        [
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "Some content",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": "stop",
                "usage": {"prompt_tokens": 100, "completion_tokens": 200},
            },
        ],
    )
]
# (base_url, raw SSE chunk lines, expected per-chunk results) for a streamed
# Mistral completion ending in a tool call; the tool-call arguments string is
# split across multiple chunks and must be concatenated by the consumer.
STREAMED_TOOL_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
    (
        "https://api.mistral.ai",
        [
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"p":"a"}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"id":"fake_id_1234","function":{"name":"some_tool","arguments":""},"index":0}]},"finish_reason":null}],"p":"abcdef"}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"function":{"name":"","arguments":"{\"some_argument\": "},"index":0}]},"finish_reason":null}],"p":"abcdefghijklmnopq"}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"id":"null","function":{"name":"","arguments":"\"some_argument_value\"}"},"index":0}]},"finish_reason":null}],"p":"abcdefghijklmnopqrstuvwxyz0123456"}',
            rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200},"p":"abcdefghijklmnopq"}',
            rb"data: [DONE]",
        ],
        [
            {
                "message": "",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "Some content",
                "finish_reason": None,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            # Tool name arrives first with an empty arguments string.
            {
                "message": "",
                "finish_reason": None,
                "tool_calls": [{"name": "some_tool", "arguments": "", "index": 0}],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            # The arguments JSON is streamed in two partial fragments.
            {
                "message": "",
                "finish_reason": None,
                "tool_calls": [
                    {"name": "", "arguments": '{"some_argument": ', "index": 0}
                ],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": None,
                "tool_calls": [
                    {"name": "", "arguments": '"some_argument_value"}', "index": 0}
                ],
                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
            },
            {
                "message": "",
                "finish_reason": "tool_calls",
                "usage": {"prompt_tokens": 100, "completion_tokens": 200},
            },
        ],
    )
]

View File

@@ -0,0 +1,248 @@
"""Test data for this module was generated using real LLM provider API responses,
with responses simplified and formatted to make them readable and maintainable.
To update or modify test parameters:
1. Make actual API calls to the target providers
2. Use the raw API responses as a base for updating test data
3. Simplify only where necessary for readability while preserving core structure
The closer test data remains to real API responses, the more reliable and accurate
the tests will be. Always prefer real API data over manually constructed examples.
"""
from __future__ import annotations
import httpx
import pytest
import respx
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
from tests.backend.data.fireworks import (
SIMPLE_CONVERSATION_PARAMS as FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
STREAMED_SIMPLE_CONVERSATION_PARAMS as FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
STREAMED_TOOL_CONVERSATION_PARAMS as FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
TOOL_CONVERSATION_PARAMS as FIREWORKS_TOOL_CONVERSATION_PARAMS,
)
from tests.backend.data.mistral import (
SIMPLE_CONVERSATION_PARAMS as MISTRAL_SIMPLE_CONVERSATION_PARAMS,
STREAMED_SIMPLE_CONVERSATION_PARAMS as MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
STREAMED_TOOL_CONVERSATION_PARAMS as MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
TOOL_CONVERSATION_PARAMS as MISTRAL_TOOL_CONVERSATION_PARAMS,
)
from vibe.core.config import ModelConfig, ProviderConfig
from vibe.core.llm.backend.generic import GenericBackend
from vibe.core.llm.backend.mistral import MistralBackend
from vibe.core.llm.exceptions import BackendError
from vibe.core.llm.types import BackendLike
from vibe.core.types import LLMChunk, LLMMessage, Role, ToolCall
class TestBackend:
    """Exercises chat-completion backends against respx-mocked provider APIs."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,json_response,result_data",
        [
            *FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
            *FIREWORKS_TOOL_CONVERSATION_PARAMS,
            *MISTRAL_SIMPLE_CONVERSATION_PARAMS,
            *MISTRAL_TOOL_CONVERSATION_PARAMS,
        ],
    )
    async def test_backend_complete(
        self, base_url: Url, json_response: JsonResponse, result_data: ResultData
    ):
        """Non-streaming complete() parses message, finish reason, usage, tools."""
        with respx.mock(base_url=base_url) as mock_api:
            mock_api.post("/v1/chat/completions").mock(
                return_value=httpx.Response(status_code=200, json=json_response)
            )
            provider = ProviderConfig(
                name="provider_name",
                api_base=f"{base_url}/v1",
                api_key_env_var="API_KEY",
            )
            # Mistral URLs are exercised through both the generic backend and
            # the Mistral-specific backend.
            BackendClasses = [
                GenericBackend,
                *([MistralBackend] if base_url == "https://api.mistral.ai" else []),
            ]
            for BackendClass in BackendClasses:
                backend: BackendLike = BackendClass(provider=provider)
                model = ModelConfig(
                    name="model_name", provider="provider_name", alias="model_alias"
                )
                messages = [LLMMessage(role=Role.user, content="Just say hi")]
                result = await backend.complete(
                    model=model,
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                )
                assert result.message.content == result_data["message"]
                assert result.finish_reason == result_data["finish_reason"]
                assert result.usage is not None
                assert (
                    result.usage.prompt_tokens == result_data["usage"]["prompt_tokens"]
                )
                assert (
                    result.usage.completion_tokens
                    == result_data["usage"]["completion_tokens"]
                )
                if result.message.tool_calls is None:
                    # BUGFIX: was `return`, which ended the whole test after the
                    # first backend and skipped MistralBackend for the simple
                    # Mistral params; `continue` still checks every backend.
                    continue
                assert len(result.message.tool_calls) == len(result_data["tool_calls"])
                for i, tool_call in enumerate(result.message.tool_calls):
                    assert (
                        tool_call.function.name == result_data["tool_calls"][i]["name"]
                    )
                    assert (
                        tool_call.function.arguments
                        == result_data["tool_calls"][i]["arguments"]
                    )
                    assert tool_call.index == result_data["tool_calls"][i]["index"]

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,chunks,result_data",
        [
            *FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
            *FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
            *MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
            *MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
        ],
    )
    async def test_backend_complete_streaming(
        self, base_url: Url, chunks: list[Chunk], result_data: list[ResultData]
    ):
        """Streaming complete_streaming() yields one parsed chunk per SSE event."""
        with respx.mock(base_url=base_url) as mock_api:
            # Serve the canned SSE byte chunks separated by blank lines.
            mock_api.post("/v1/chat/completions").mock(
                return_value=httpx.Response(
                    status_code=200,
                    stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
                    headers={"Content-Type": "text/event-stream"},
                )
            )
            provider = ProviderConfig(
                name="provider_name",
                api_base=f"{base_url}/v1",
                api_key_env_var="API_KEY",
            )
            BackendClasses = [
                GenericBackend,
                *([MistralBackend] if base_url == "https://api.mistral.ai" else []),
            ]
            for BackendClass in BackendClasses:
                backend: BackendLike = BackendClass(provider=provider)
                model = ModelConfig(
                    name="model_name", provider="provider_name", alias="model_alias"
                )
                messages = [
                    LLMMessage(role=Role.user, content="List files in current dir")
                ]
                results: list[LLMChunk] = []
                async for result in backend.complete_streaming(
                    model=model,
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                ):
                    results.append(result)
                # strict=True guarantees the backend produced exactly one parsed
                # chunk per expected entry.
                for result, expected_result in zip(results, result_data, strict=True):
                    assert result.message.content == expected_result["message"]
                    assert result.finish_reason == expected_result["finish_reason"]
                    assert result.usage is not None
                    assert (
                        result.usage.prompt_tokens
                        == expected_result["usage"]["prompt_tokens"]
                    )
                    assert (
                        result.usage.completion_tokens
                        == expected_result["usage"]["completion_tokens"]
                    )
                    if result.message.tool_calls is None:
                        continue
                    for i, tool_call in enumerate(result.message.tool_calls):
                        assert (
                            tool_call.function.name
                            == expected_result["tool_calls"][i]["name"]
                        )
                        assert (
                            tool_call.function.arguments
                            == expected_result["tool_calls"][i]["arguments"]
                        )
                        assert (
                            tool_call.index == expected_result["tool_calls"][i]["index"]
                        )

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "base_url,backend_class,response",
        [
            (
                "https://api.fireworks.ai",
                GenericBackend,
                httpx.Response(status_code=500, text="Internal Server Error"),
            ),
            (
                "https://api.fireworks.ai",
                GenericBackend,
                httpx.Response(status_code=429, text="Rate Limit Exceeded"),
            ),
            (
                "https://api.mistral.ai",
                MistralBackend,
                httpx.Response(status_code=500, text="Internal Server Error"),
            ),
            (
                "https://api.mistral.ai",
                MistralBackend,
                httpx.Response(status_code=429, text="Rate Limit Exceeded"),
            ),
        ],
    )
    async def test_backend_complete_streaming_error(
        self,
        base_url: Url,
        backend_class: type[MistralBackend | GenericBackend],
        response: httpx.Response,
    ):
        """HTTP 5xx/429 responses surface as BackendError with status and reason."""
        with respx.mock(base_url=base_url) as mock_api:
            mock_api.post("/v1/chat/completions").mock(return_value=response)
            provider = ProviderConfig(
                name="provider_name",
                api_base=f"{base_url}/v1",
                api_key_env_var="API_KEY",
            )
            backend = backend_class(provider=provider)
            model = ModelConfig(
                name="model_name", provider="provider_name", alias="model_alias"
            )
            messages = [LLMMessage(role=Role.user, content="Just say hi")]
            with pytest.raises(BackendError) as e:
                async for _ in backend.complete_streaming(
                    model=model,
                    messages=messages,
                    temperature=0.2,
                    tools=None,
                    max_tokens=None,
                    tool_choice=None,
                    extra_headers=None,
                ):
                    pass
            assert e.value.status == response.status_code
            assert e.value.reason == response.reason_phrase
            # Plain-text error bodies carry no structured payload.
            assert e.value.parsed_error is None

87
tests/conftest.py Normal file
View File

@@ -0,0 +1,87 @@
from __future__ import annotations
import sys
from typing import Any
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
import pytest
_in_mem_config: dict[str, Any] = {}
class InMemSettingsSource(PydanticBaseSettingsSource):
    """Settings source that serves values from the module-level _in_mem_config
    dict instead of environment variables or files.

    The redundant ``__init__`` that only delegated to ``super().__init__`` has
    been removed; the inherited constructor is used as-is.
    """

    def get_field_value(
        self, field: FieldInfo, field_name: str
    ) -> tuple[Any, str, bool]:
        # (value, field_key, value_is_complex) per the settings-source protocol.
        return _in_mem_config.get(field_name), field_name, False

    def __call__(self) -> dict[str, Any]:
        return _in_mem_config
@pytest.fixture(autouse=True, scope="session")
def _patch_vibe_config() -> None:
    """Patch VibeConfig.settings_customise_sources to only use init_settings in tests.

    This ensures that even production code that creates VibeConfig instances
    will only use init_settings and ignore environment variables and config files.
    Runs once per test session before any tests execute.
    """
    from vibe.core.config import VibeConfig

    def patched_settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # Explicit constructor kwargs win; everything else comes from the
        # in-memory store. env/dotenv/file sources are deliberately dropped.
        return (init_settings, InMemSettingsSource(settings_cls))

    VibeConfig.settings_customise_sources = classmethod(
        patched_settings_customise_sources
    )  # type: ignore[assignment]

    def dump_config(cls, config: dict[str, Any]) -> None:
        # Capture writes in memory instead of touching a real config file.
        global _in_mem_config
        _in_mem_config = config

    VibeConfig.dump_config = classmethod(dump_config)  # type: ignore[assignment]

    def patched_load(cls, agent: str | None = None, **overrides: Any) -> Any:
        # Bypass file/env loading entirely; construct straight from overrides.
        return cls(**overrides)

    VibeConfig.load = classmethod(patched_load)  # type: ignore[assignment]
@pytest.fixture(autouse=True)
def _reset_in_mem_config() -> None:
    """Reset in-memory config before each test to prevent test isolation issues.

    This ensures that each test starts with a clean configuration state,
    preventing race conditions and test interference when tests run in parallel
    or when VibeConfig.save_updates() modifies the shared _in_mem_config dict.
    """
    # Rebind (rather than mutate) so any previously captured dict is discarded.
    global _in_mem_config
    _in_mem_config = {}
@pytest.fixture(autouse=True)
def _mock_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
    # Provide a dummy MISTRAL_API_KEY for every test; no real secret needed.
    monkeypatch.setenv("MISTRAL_API_KEY", "mock")
@pytest.fixture(autouse=True)
def _mock_platform(monkeypatch: pytest.MonkeyPatch) -> None:
    """Mock platform to be Linux with /bin/sh shell for consistent test behavior.

    This ensures that platform-specific system prompt generation is consistent
    across all tests regardless of the actual platform running the tests.
    """
    monkeypatch.setattr(sys, "platform", "linux")
    monkeypatch.setenv("SHELL", "/bin/sh")

View File

@@ -0,0 +1,51 @@
from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
import tomllib
import tomli_w
from vibe.core import config
from vibe.core.config import VibeConfig
def _restore_dump_config(config_file: Path):
    """Swap VibeConfig.dump_config for one that writes TOML to config_file.

    Returns the previously installed dump_config so callers can reinstate it.
    """
    previous = VibeConfig.dump_config

    def _write_to_disk(cls, config_dict: dict) -> None:
        try:
            with config_file.open("wb") as handle:
                tomli_w.dump(config_dict, handle)
        except OSError:
            # Fall back to a naive line-per-key dump if binary writing fails.
            lines = (
                f"{key} = {value!r}"
                for key, value in config_dict.items()
                if value is not None
            )
            config_file.write_text("\n".join(lines), encoding="utf-8")

    VibeConfig.dump_config = classmethod(_write_to_disk)  # type: ignore[assignment]
    return previous
@contextmanager
def _migrate_config_file(tmp_path: Path, content: str):
    """Run VibeConfig._migrate() against a temp config file holding `content`.

    Yields the path to the (possibly rewritten) config file. The patched
    globals — config.CONFIG_FILE and VibeConfig.dump_config — are restored on
    exit regardless of migration errors.
    """
    config_file = tmp_path / "config.toml"
    config_file.write_text(content, encoding="utf-8")
    original_config_file = config.CONFIG_FILE
    # Install a real on-disk dump_config and keep the prior one for restore.
    original_dump_config = _restore_dump_config(config_file)
    try:
        config.CONFIG_FILE = config_file
        VibeConfig._migrate()
        yield config_file
    finally:
        config.CONFIG_FILE = original_config_file
        VibeConfig.dump_config = original_dump_config
def _load_migrated_config(config_file: Path) -> dict:
with config_file.open("rb") as f:
return tomllib.load(f)

0
tests/mock/__init__.py Normal file
View File

View File

@@ -0,0 +1,16 @@
from __future__ import annotations
from contextlib import contextmanager
from vibe.core.config import Backend
from vibe.core.llm.backend.factory import BACKEND_FACTORY
@contextmanager
def mock_backend_factory(backend_type: Backend, factory_func):
    """Temporarily replace the factory registered for backend_type.

    The original factory is restored on exit, even if the body raises.
    """
    original = BACKEND_FACTORY[backend_type]
    try:
        BACKEND_FACTORY[backend_type] = factory_func
        yield
    finally:
        BACKEND_FACTORY[backend_type] = original

View File

@@ -0,0 +1,66 @@
"""Wrapper script that intercepts LLM calls when mocking is enabled.
This script is used to mock the LLM calls when testing the CLI.
Mocked returns are stored in the VIBE_MOCK_LLM_DATA environment variable.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
import json
import os
import sys
from unittest.mock import patch
from pydantic import ValidationError
from tests import TESTS_ROOT
from tests.mock.utils import MOCK_DATA_ENV_VAR
from vibe.core.types import LLMChunk
def mock_llm_output() -> None:
    """Patch backend completion methods to replay canned chunks from the env.

    Reads MOCK_DATA_ENV_VAR, validates its JSON payload into LLMChunk objects,
    and patches both Mistral and generic backends (streaming and non-streaming)
    to consume those chunks in order from one shared iterator.

    Raises:
        ValueError: if the env var is unset or its payload fails validation.
    """
    sys.path.insert(0, str(TESTS_ROOT))
    # Apply mocking before importing any vibe modules
    mock_data_str = os.environ.get(MOCK_DATA_ENV_VAR)
    if not mock_data_str:
        raise ValueError(f"{MOCK_DATA_ENV_VAR} is not set")
    mock_data = json.loads(mock_data_str)
    try:
        chunks = [LLMChunk.model_validate(chunk) for chunk in mock_data]
    except ValidationError as e:
        raise ValueError(f"Invalid mock data: {e}") from e
    # One shared iterator: every mocked LLM call consumes the next chunk.
    chunk_iterable = iter(chunks)

    async def mock_complete(*args, **kwargs) -> LLMChunk:
        return next(chunk_iterable)

    async def mock_complete_streaming(*args, **kwargs) -> AsyncGenerator[LLMChunk]:
        yield next(chunk_iterable)

    # NOTE(review): patches are started and never stopped — they must stay
    # active for the lifetime of this one-shot mocked process.
    patch(
        "vibe.core.llm.backend.mistral.MistralBackend.complete",
        side_effect=mock_complete,
    ).start()
    patch(
        "vibe.core.llm.backend.generic.GenericBackend.complete",
        side_effect=mock_complete,
    ).start()
    patch(
        "vibe.core.llm.backend.mistral.MistralBackend.complete_streaming",
        side_effect=mock_complete_streaming,
    ).start()
    patch(
        "vibe.core.llm.backend.generic.GenericBackend.complete_streaming",
        side_effect=mock_complete_streaming,
    ).start()
if __name__ == "__main__":
    # Install the LLM mocks before importing vibe's entrypoint so the patched
    # backend methods are already in place when the app starts.
    mock_llm_output()
    from vibe.acp.entrypoint import main
    main()

42
tests/mock/utils.py Normal file
View File

@@ -0,0 +1,42 @@
from __future__ import annotations
import json
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role, ToolCall
MOCK_DATA_ENV_VAR = "VIBE_MOCK_LLM_DATA"
def mock_llm_chunk(
    content: str = "Hello!",
    role: Role = Role.assistant,
    tool_calls: list[ToolCall] | None = None,
    name: str | None = None,
    tool_call_id: str | None = None,
    finish_reason: str | None = None,
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
) -> LLMChunk:
    """Build an LLMChunk with sensible defaults for mocking LLM output."""
    usage = LLMUsage(
        prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
    )
    return LLMChunk(
        message=LLMMessage(
            role=role,
            content=content,
            tool_calls=tool_calls,
            name=name,
            tool_call_id=tool_call_id,
        ),
        usage=usage,
        finish_reason=finish_reason,
    )
def get_mocking_env(mock_chunks: list[LLMChunk] | None = None) -> dict[str, str]:
    """Serialize mock_chunks into the env var consumed by the mock entrypoint.

    Defaults to a single plain "Hello!" chunk when no chunks are given.
    """
    if mock_chunks is None:
        mock_chunks = [mock_llm_chunk()]
    # Call model_dump on each instance (was the unbound LLMChunk.model_dump(x)
    # class-call form, which obscured intent).
    mock_data = [mock_chunk.model_dump() for mock_chunk in mock_chunks]
    return {MOCK_DATA_ENV_VAR: json.dumps(mock_data)}

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from pathlib import Path
import sys
from typing import override
import pytest
from textual.app import App
from vibe.setup import onboarding
class StubApp(App[str | None]):
    """Textual App double whose run() immediately returns a canned value."""
    def __init__(self, return_value: str | None) -> None:
        super().__init__()
        # Value that run() will hand back, standing in for the app's result.
        self._return_value = return_value
    @override
    def run(self, *args: object, **kwargs: object) -> str | None:
        # Skip the real event loop entirely; callers only inspect the result.
        return self._return_value
def _patch_env_file(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
    """Point onboarding's GLOBAL_ENV_FILE at a throwaway .env and return it."""
    target = tmp_path / ".env"
    monkeypatch.setattr(onboarding, "GLOBAL_ENV_FILE", target, raising=False)
    return target
def _exit_raiser(code: int = 0) -> None:
raise SystemExit(code)
def test_exits_on_cancel(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
) -> None:
    """Cancelling onboarding (run() returns None) exits 0 with a goodbye message."""
    _patch_env_file(monkeypatch, tmp_path)
    # Replace sys.exit so the exit is observable as a SystemExit exception.
    monkeypatch.setattr(sys, "exit", _exit_raiser)
    with pytest.raises(SystemExit) as excinfo:
        onboarding.run_onboarding(StubApp(None))
    assert excinfo.value.code == 0
    out = capsys.readouterr().out
    assert "Setup cancelled. See you next time!" in out
def test_warns_on_save_error(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
) -> None:
    """A 'save_error:<detail>' result prints a warning including the detail."""
    _patch_env_file(monkeypatch, tmp_path)
    monkeypatch.setattr(sys, "exit", _exit_raiser)
    onboarding.run_onboarding(StubApp("save_error:disk full"))
    out = capsys.readouterr().out
    assert "Could not save API key" in out
    assert "disk full" in out
def test_successfully_completes(
    monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
) -> None:
    """A 'completed' result finishes silently: no console output at all."""
    _patch_env_file(monkeypatch, tmp_path)
    monkeypatch.setattr(sys, "exit", _exit_raiser)
    onboarding.run_onboarding(StubApp("completed"))
    out = capsys.readouterr().out
    assert out == ""

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
from typing import Any
import pytest
from textual.events import Resize
from textual.geometry import Size
from textual.pilot import Pilot
from textual.widgets import Input
from vibe.core import config as core_config
from vibe.setup.onboarding import OnboardingApp
import vibe.setup.onboarding.screens.api_key as api_key_module
from vibe.setup.onboarding.screens.api_key import ApiKeyScreen
from vibe.setup.onboarding.screens.theme_selection import THEMES, ThemeSelectionScreen
async def _wait_for(
condition: Callable[[], bool],
pilot: Pilot,
timeout: float = 5.0,
interval: float = 0.05,
) -> None:
elapsed = 0.0
while not condition():
await pilot.pause(interval)
if (elapsed := elapsed + interval) >= timeout:
msg = "Timed out waiting for condition."
raise AssertionError(msg)
@pytest.fixture()
def onboarding_app(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> tuple[OnboardingApp, Path, dict[str, Any]]:
    """Provide an OnboardingApp wired to a temporary VIBE_HOME.

    Returns (app, env_file, saved_updates) where saved_updates accumulates
    every dict passed to VibeConfig.save_updates during the test.
    """
    vibe_home = tmp_path / ".vibe"
    env_file = vibe_home / ".env"
    saved_updates: dict[str, Any] = {}
    def record_updates(updates: dict[str, Any]) -> None:
        saved_updates.update(updates)
    monkeypatch.setenv("VIBE_HOME", str(vibe_home))
    # Both modules hold their own references to these globals, so patch each.
    for module in (core_config, api_key_module):
        monkeypatch.setattr(module, "GLOBAL_CONFIG_DIR", vibe_home, raising=False)
        monkeypatch.setattr(module, "GLOBAL_ENV_FILE", env_file, raising=False)
    # Intercept config persistence instead of writing a real config file.
    monkeypatch.setattr(
        core_config.VibeConfig,
        "save_updates",
        classmethod(lambda cls, updates: record_updates(updates)),
    )
    return OnboardingApp(), env_file, saved_updates
async def pass_welcome_screen(pilot: Pilot) -> None:
    """Advance past the welcome screen onto theme selection.

    Waits for the "#enter-hint" widget to become visible, presses enter, and
    waits until the ThemeSelectionScreen is active.
    """
    welcome_screen = pilot.app.get_screen("welcome")
    await _wait_for(
        lambda: not welcome_screen.query_one("#enter-hint").has_class("hidden"), pilot
    )
    await pilot.press("enter")
    await _wait_for(lambda: isinstance(pilot.app.screen, ThemeSelectionScreen), pilot)
@pytest.mark.asyncio
async def test_ui_gets_through_the_onboarding_successfully(
    onboarding_app: tuple[OnboardingApp, Path, dict[str, Any]],
) -> None:
    """Happy path: welcome -> theme -> API key; key and theme are persisted."""
    app, env_file, config_updates = onboarding_app
    api_key_value = "sk-onboarding-test-key"
    async with app.run_test() as pilot:
        await pass_welcome_screen(pilot)
        # Accept the default theme to reach the API key screen.
        await pilot.press("enter")
        await _wait_for(lambda: isinstance(app.screen, ApiKeyScreen), pilot)
        api_screen = app.screen
        input_widget = api_screen.query_one("#key", Input)
        await pilot.press(*api_key_value)
        assert input_widget.value == api_key_value
        await pilot.press("enter")
        await _wait_for(lambda: app.return_value is not None, pilot, timeout=2.0)
    assert app.return_value == "completed"
    # The API key must have been written to the patched global .env file.
    assert env_file.is_file()
    env_contents = env_file.read_text(encoding="utf-8")
    assert "MISTRAL_API_KEY" in env_contents
    assert api_key_value in env_contents
    assert config_updates.get("textual_theme") == app.theme
@pytest.mark.asyncio
async def test_ui_can_pick_a_theme_and_saves_selection(
    onboarding_app: tuple[OnboardingApp, Path, dict[str, Any]],
) -> None:
    """Arrow-key navigation selects a theme and the choice is persisted."""
    app, _, config_updates = onboarding_app
    async with app.run_test() as pilot:
        await pass_welcome_screen(pilot)
        theme_screen = app.screen
        app.post_message(
            Resize(Size(40, 10), Size(40, 10))
        )  # trigger the resize event handler
        preview = theme_screen.query_one("#preview")
        assert preview.styles.max_height is not None
        target_theme = "gruvbox"
        assert target_theme in THEMES
        # Compute how many "down" presses reach the target, wrapping around.
        start_index = THEMES.index(app.theme)
        target_index = THEMES.index(target_theme)
        steps_down = (target_index - start_index) % len(THEMES)
        await pilot.press(*["down"] * steps_down)
        assert app.theme == target_theme
        await pilot.press("enter")
        await _wait_for(lambda: isinstance(app.screen, ApiKeyScreen), pilot)
    assert config_updates.get("textual_theme") == target_theme

View File

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 22 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 26 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 24 KiB

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
from rich.style import Style
from textual.widgets.text_area import TextAreaTheme
from tests.stubs.fake_backend import FakeBackend
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.textual_ui.widgets.chat_input import ChatTextArea
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
def default_config() -> VibeConfig:
    """Default configuration for snapshot testing.

    Removes as much interference as possible from the snapshot comparison,
    in order to get a clean pixel-to-pixel comparison:
    - disables session logging,
    - pins the "gruvbox" theme,
    - disables the welcome banner animation,
    - forces a fixed displayed workdir.
    """
    snapshot_config = VibeConfig(
        textual_theme="gruvbox",
        session_logging=SessionLoggingConfig(enabled=False),
        displayed_workdir="/test/workdir",
        disable_welcome_banner_animation=True,
    )
    return snapshot_config
class BaseSnapshotTestApp(VibeApp):
    """VibeApp variant used as the base app for snapshot tests.

    Wires in an Agent backed by a FakeBackend (no network) and hides the
    chat input cursor so renders are deterministic.
    """
    # Stylesheet resolved relative to this tests directory.
    CSS_PATH = "../../vibe/cli/textual_ui/app.tcss"
    def __init__(self, config: VibeConfig | None = None, **kwargs) -> None:
        config = config or default_config()
        super().__init__(config=config, **kwargs)
        # Replace the app's agent with one driven by a FakeBackend.
        self.agent = Agent(
            config,
            auto_approve=self.auto_approve,
            enable_streaming=self.enable_streaming,
            backend=FakeBackend(),
        )
    async def on_mount(self) -> None:
        await super().on_mount()
        self._hide_chat_input_cursor()
    def _hide_chat_input_cursor(self) -> None:
        # The blinking cursor is non-deterministic frame-to-frame, so register
        # a theme with an empty cursor style and switch the text area to it.
        text_area = self.query_one(ChatTextArea)
        hidden_cursor_theme = TextAreaTheme(name="hidden_cursor", cursor_style=Style())
        text_area.register_theme(hidden_cursor_theme)
        text_area.theme = "hidden_cursor"

View File

@@ -0,0 +1,20 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable, Iterable
from pathlib import PurePath
from typing import Protocol
from textual.app import App
from textual.pilot import Pilot
class SnapCompare(Protocol):
    """Typing protocol describing the snap_compare fixture used by snapshot tests.

    Callable with an app (path string, PurePath or App instance) plus optional
    key presses, terminal size and a run_before hook; returns True when the
    render matches the stored snapshot.
    """
    def __call__(
        self,
        app: str | PurePath | App,
        /,
        *,
        press: Iterable[str] = ...,
        terminal_size: tuple[int, int] = ...,
        run_before: (Callable[[Pilot], Awaitable[None] | None] | None) = ...,
    ) -> bool: ...

View File

@@ -0,0 +1,43 @@
from __future__ import annotations
from textual.pilot import Pilot
from tests.mock.utils import mock_llm_chunk
from tests.snapshots.base_snapshot_test_app import BaseSnapshotTestApp, default_config
from tests.snapshots.snap_compare import SnapCompare
from tests.stubs.fake_backend import FakeBackend
from vibe.core.agent import Agent
class SnapshotTestAppWithConversation(BaseSnapshotTestApp):
    """Snapshot app whose backend replies with a single canned assistant message."""
    def __init__(self) -> None:
        config = default_config()
        # Canned reply with fixed usage numbers for a stable snapshot.
        fake_backend = FakeBackend(
            results=[
                mock_llm_chunk(
                    content="I'm the Vibe agent and I'm ready to help.",
                    prompt_tokens=10_000,
                    completion_tokens=2_500,
                )
            ]
        )
        super().__init__(config=config)
        # Rebuild the agent so it uses this test's preloaded backend.
        self.agent = Agent(
            config,
            auto_approve=self.auto_approve,
            enable_streaming=self.enable_streaming,
            backend=fake_backend,
        )
def test_snapshot_shows_basic_conversation(snap_compare: SnapCompare) -> None:
    """Type a message, wait for the canned reply, and snapshot the result."""
    async def run_before(pilot: Pilot) -> None:
        await pilot.press(*"Hello there, who are you?")
        await pilot.press("enter")
        # Give the app time to render the fake backend's response.
        await pilot.pause(0.4)
    assert snap_compare(
        "test_ui_snapshot_basic_conversation.py:SnapshotTestAppWithConversation",
        terminal_size=(120, 36),
        run_before=run_before,
    )

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
from textual.pilot import Pilot
from textual.widgets.markdown import MarkdownFence
from tests.snapshots.snap_compare import SnapCompare
from vibe.cli.textual_ui.widgets.messages import AssistantMessage
def test_snapshot_allows_horizontal_scrolling_for_long_code_blocks(
    snap_compare: SnapCompare,
) -> None:
    """A fenced code block with an over-long line can be scrolled horizontally;
    the snapshot captures the view after scrolling 15 cells to the right."""
    assistant_message_md = """Here's a very long print instruction:
```python
def lorem_ipsum():
    # Print a very long line (Lorem Ipsum)
    print("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem.")
```
The `print` statement includes a very long line of Lorem Ipsum text to demonstrate a lengthy output."""
    async def run_before(pilot: Pilot) -> None:
        app = pilot.app
        # Mount an AssistantMessage directly instead of driving a backend.
        assistant_message = AssistantMessage(assistant_message_md)
        messages_area = app.query_one("#messages")
        await messages_area.mount(assistant_message)
        await assistant_message.write_initial_content()
        await pilot.pause(0.1)
        markdown_fence = app.query_one(MarkdownFence)
        # Scroll right so the snapshot proves horizontal scrolling works.
        markdown_fence.scroll_relative(x=15, immediate=True)
    assert snap_compare(
        "base_snapshot_test_app.py:BaseSnapshotTestApp",
        run_before=run_before,
        terminal_size=(120, 36),
    )

View File

@@ -0,0 +1,28 @@
from __future__ import annotations
from textual.pilot import Pilot
from tests.snapshots.base_snapshot_test_app import BaseSnapshotTestApp, default_config
from tests.snapshots.snap_compare import SnapCompare
from vibe.cli.update_notifier import FakeVersionUpdateGateway, VersionUpdate
class SnapshotTestAppWithUpdate(BaseSnapshotTestApp):
    """Snapshot app with update checks on and a fake 0.2.0 release available."""
    def __init__(self) -> None:
        config = default_config()
        config.enable_update_checks = True
        # Fake gateway reports a newer version without any network access.
        version_update_notifier = FakeVersionUpdateGateway(
            update=VersionUpdate(latest_version="0.2.0")
        )
        super().__init__(config=config, version_update_notifier=version_update_notifier)
def test_snapshot_shows_release_update_notification(snap_compare: SnapCompare) -> None:
    """The UI shows a notification when a newer release is reported."""
    async def run_before(pilot: Pilot) -> None:
        # Let the update notification render before capturing.
        await pilot.pause(0.2)
    assert snap_compare(
        "test_ui_snapshot_release_update_notification.py:SnapshotTestAppWithUpdate",
        terminal_size=(120, 36),
        run_before=run_before,
    )

115
tests/stubs/fake_backend.py Normal file
View File

@@ -0,0 +1,115 @@
from __future__ import annotations
from collections.abc import AsyncGenerator, Callable, Iterable
from tests.mock.utils import mock_llm_chunk
from vibe.core.types import LLMChunk, LLMMessage
class FakeBackend:
    """Minimal async backend stub to drive Agent.act without network.

    Provide a finite sequence of LLMResult objects to be returned by
    `complete`. When exhausted, returns an empty assistant message.
    """
    def __init__(
        self,
        results: Iterable[LLMChunk] | None = None,
        *,
        token_counter: Callable[[list[LLMMessage]], int] | None = None,
        exception_to_raise: Exception | None = None,
    ) -> None:
        # Chunks consumed front-to-back by complete()/complete_streaming().
        self._chunks = list(results or [])
        # Per-call recordings that tests assert against.
        self._requests_messages: list[list[LLMMessage]] = []
        self._requests_extra_headers: list[dict[str, str] | None] = []
        self._count_tokens_calls: list[list[LLMMessage]] = []
        self._token_counter = token_counter or self._default_token_counter
        # When set, complete()/complete_streaming() raise instead of answering.
        self._exception_to_raise = exception_to_raise
    @property
    def requests_messages(self) -> list[list[LLMMessage]]:
        # Message lists passed to each completion request, in call order.
        return self._requests_messages
    @property
    def requests_extra_headers(self) -> list[dict[str, str] | None]:
        # extra_headers passed to each completion request, in call order.
        return self._requests_extra_headers
    @staticmethod
    def _default_token_counter(messages: list[LLMMessage]) -> int:
        # Constant count keeps token-dependent logic deterministic in tests.
        return 1
    async def __aenter__(self):
        return self
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to clean up; exceptions are not suppressed.
        return None
    async def complete(
        self,
        *,
        model,
        messages,
        temperature,
        tools,
        tool_choice,
        extra_headers,
        max_tokens,
    ) -> LLMChunk:
        """Return the next queued chunk, or an empty stop chunk when exhausted."""
        if self._exception_to_raise:
            raise self._exception_to_raise
        self._requests_messages.append(messages)
        self._requests_extra_headers.append(extra_headers)
        if self._chunks:
            chunk = self._chunks.pop(0)
            if not self._chunks:
                # The last queued chunk is forced to terminate the turn.
                chunk = chunk.model_copy(update={"finish_reason": "stop"})
            return chunk
        return mock_llm_chunk(content="", finish_reason="stop")
    async def complete_streaming(
        self,
        *,
        model,
        messages,
        temperature,
        tools,
        tool_choice,
        extra_headers,
        max_tokens,
    ) -> AsyncGenerator[LLMChunk]:
        """Yield queued chunks, guaranteeing exactly one chunk with a finish_reason.

        Stops at the first terminal chunk; any chunks still queued after it
        remain available for a subsequent call.
        """
        if self._exception_to_raise:
            raise self._exception_to_raise
        self._requests_messages.append(messages)
        self._requests_extra_headers.append(extra_headers)
        has_final_chunk = False
        while self._chunks:
            chunk = self._chunks.pop(0)
            is_last_provided_chunk = not self._chunks
            if is_last_provided_chunk:
                # Force the final queued chunk to end the stream.
                chunk = chunk.model_copy(update={"finish_reason": "stop"})
            if chunk.finish_reason is not None:
                has_final_chunk = True
            yield chunk
            if has_final_chunk:
                break
        if not has_final_chunk:
            # Empty queue: still emit one terminal chunk so callers finish.
            yield mock_llm_chunk(content="", finish_reason="stop")
    async def count_tokens(
        self,
        *,
        model,
        messages,
        temperature=0.0,
        tools,
        tool_choice=None,
        extra_headers,
    ) -> int:
        """Record the call and delegate to the injected token counter."""
        self._count_tokens_calls.append(list(messages))
        return self._token_counter(messages)

View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from acp import (
Agent,
AgentSideConnection,
CreateTerminalRequest,
KillTerminalCommandRequest,
KillTerminalCommandResponse,
ReadTextFileRequest,
ReadTextFileResponse,
ReleaseTerminalRequest,
ReleaseTerminalResponse,
RequestPermissionRequest,
RequestPermissionResponse,
SessionNotification,
TerminalHandle,
TerminalOutputRequest,
TerminalOutputResponse,
WaitForTerminalExitRequest,
WaitForTerminalExitResponse,
WriteTextFileRequest,
WriteTextFileResponse,
)
class FakeAgentSideConnection(AgentSideConnection):
    """AgentSideConnection test double that records session updates.

    Every other client-side capability raises NotImplementedError so a test
    fails loudly if the agent under test unexpectedly uses it.
    """
    def __init__(self, to_agent: Callable[[AgentSideConnection], Agent]) -> None:
        # NOTE(review): super().__init__() is deliberately not called, so no
        # real transport is set up — confirm the base class tolerates this.
        self._session_updates = []
        to_agent(self)
    async def sessionUpdate(self, params: SessionNotification) -> None:
        # The only supported operation: capture the notification for asserts.
        self._session_updates.append(params)
    async def requestPermission(
        self, params: RequestPermissionRequest
    ) -> RequestPermissionResponse:
        raise NotImplementedError()
    async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse:
        raise NotImplementedError()
    async def writeTextFile(
        self, params: WriteTextFileRequest
    ) -> WriteTextFileResponse | None:
        raise NotImplementedError()
    async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle:
        raise NotImplementedError()
    async def terminalOutput(
        self, params: TerminalOutputRequest
    ) -> TerminalOutputResponse:
        raise NotImplementedError()
    async def releaseTerminal(
        self, params: ReleaseTerminalRequest
    ) -> ReleaseTerminalResponse | None:
        raise NotImplementedError()
    async def waitForTerminalExit(
        self, params: WaitForTerminalExitRequest
    ) -> WaitForTerminalExitResponse:
        raise NotImplementedError()
    async def killTerminal(
        self, params: KillTerminalCommandRequest
    ) -> KillTerminalCommandResponse | None:
        raise NotImplementedError()
    async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError()
    async def extNotification(self, method: str, params: dict[str, Any]) -> None:
        raise NotImplementedError()
    async def close(self) -> None:
        raise NotImplementedError()
    async def __aenter__(self) -> AgentSideConnection:
        return self
    async def __aexit__(self, exc_type, exc, tb) -> None:
        # NOTE(review): close() raises NotImplementedError, so exiting the
        # async context manager always raises — confirm this is intended.
        await self.close()

30
tests/stubs/fake_tool.py Normal file
View File

@@ -0,0 +1,30 @@
from __future__ import annotations
from pydantic import BaseModel
from vibe.core.tools.base import BaseTool, BaseToolConfig, BaseToolState
class FakeToolArgs(BaseModel):
    """Argument model for FakeTool; the tool takes no arguments."""
    pass
class FakeToolResult(BaseModel):
    """Result model for FakeTool."""
    # Fixed payload returned by FakeTool.run on success.
    message: str = "fake tool executed"
class FakeToolState(BaseToolState):
    """State model for FakeTool; adds nothing beyond the base state."""
    pass
class FakeTool(BaseTool[FakeToolArgs, FakeToolResult, BaseToolConfig, FakeToolState]):
    """Minimal BaseTool implementation for tests.

    Succeeds with a canned FakeToolResult unless a test assigns
    _exception_to_raise, in which case run() raises it.
    """
    # When set by a test, run() raises this instead of returning a result.
    _exception_to_raise: BaseException | None = None
    @classmethod
    def get_name(cls) -> str:
        # Registered tool name; note it is "stub_tool", not "fake_tool".
        return "stub_tool"
    async def run(self, args: FakeToolArgs) -> FakeToolResult:
        if self._exception_to_raise:
            raise self._exception_to_raise
        return FakeToolResult()

View File

@@ -0,0 +1,56 @@
from __future__ import annotations
import pytest
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
from vibe.core.types import (
AssistantEvent,
CompactEndEvent,
CompactStartEvent,
LLMMessage,
Role,
)
@pytest.mark.asyncio
async def test_auto_compact_triggers_and_batches_observer() -> None:
    """When the context exceeds the threshold, a compact runs before the answer
    and the observer still receives system/user/assistant messages in order."""
    observed: list[tuple[Role, str | None]] = []
    def observer(msg: LLMMessage) -> None:
        observed.append((msg.role, msg.content))
    # First chunk is consumed by the compaction summary, second by the answer.
    backend = FakeBackend([
        mock_llm_chunk(content="<summary>"),
        mock_llm_chunk(content="<final>"),
    ])
    cfg = VibeConfig(
        session_logging=SessionLoggingConfig(enabled=False), auto_compact_threshold=1
    )
    agent = Agent(cfg, message_observer=observer, backend=backend)
    # Force the context over the threshold (2 > 1) so act() compacts first.
    agent.stats.context_tokens = 2
    events = [ev async for ev in agent.act("Hello")]
    assert len(events) == 3
    assert isinstance(events[0], CompactStartEvent)
    assert isinstance(events[1], CompactEndEvent)
    assert isinstance(events[2], AssistantEvent)
    start: CompactStartEvent = events[0]
    end: CompactEndEvent = events[1]
    final: AssistantEvent = events[2]
    assert start.current_context_tokens == 2
    assert start.threshold == 1
    assert end.old_context_tokens == 2
    assert end.new_context_tokens >= 1
    assert final.content == "<final>"
    roles = [r for r, _ in observed]
    assert roles == [Role.system, Role.user, Role.assistant]
    assert (
        observed[1][1] is not None
        and "Last request from user was: Hello" in observed[1][1]
    )
    assert observed[2][1] == "<final>"

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
import pytest
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
@pytest.fixture
def vibe_config() -> VibeConfig:
    """VibeConfig with session logging disabled so tests write no log files."""
    return VibeConfig(session_logging=SessionLoggingConfig(enabled=False))
@pytest.mark.asyncio
async def test_passes_x_affinity_header_when_asking_an_answer(vibe_config: VibeConfig):
    """Non-streaming requests carry an x-affinity header equal to the session id."""
    backend = FakeBackend([mock_llm_chunk(content="Response", finish_reason="stop")])
    agent = Agent(vibe_config, backend=backend)
    async for _ in agent.act("Hello"):
        pass
    assert len(backend.requests_extra_headers) > 0
    sent_headers = backend.requests_extra_headers[0]
    assert sent_headers is not None
    assert "x-affinity" in sent_headers
    assert sent_headers["x-affinity"] == agent.session_id
@pytest.mark.asyncio
async def test_passes_x_affinity_header_when_asking_an_answer_streaming(
    vibe_config: VibeConfig,
):
    """Streaming requests carry an x-affinity header equal to the session id."""
    backend = FakeBackend([mock_llm_chunk(content="Response", finish_reason="stop")])
    agent = Agent(vibe_config, backend=backend, enable_streaming=True)
    async for _ in agent.act("Hello"):
        pass
    assert len(backend.requests_extra_headers) > 0
    sent_headers = backend.requests_extra_headers[0]
    assert sent_headers is not None
    assert "x-affinity" in sent_headers
    assert sent_headers["x-affinity"] == agent.session_id
@pytest.mark.asyncio
async def test_updates_tokens_stats_based_on_backend_response(vibe_config: VibeConfig):
    """context_tokens becomes prompt_tokens + completion_tokens of the reply."""
    chunk = mock_llm_chunk(
        content="Response",
        finish_reason="stop",
        prompt_tokens=100,
        completion_tokens=50,
    )
    backend = FakeBackend([chunk])
    agent = Agent(vibe_config, backend=backend)
    [_ async for _ in agent.act("Hello")]
    # 100 prompt + 50 completion tokens
    assert agent.stats.context_tokens == 150
@pytest.mark.asyncio
async def test_updates_tokens_stats_based_on_backend_response_streaming(
    vibe_config: VibeConfig,
):
    """Streaming: context_tokens reflects the final chunk's usage numbers."""
    final_chunk = mock_llm_chunk(
        content="Complete",
        finish_reason="stop",
        prompt_tokens=200,
        completion_tokens=75,
    )
    backend = FakeBackend([final_chunk])
    agent = Agent(vibe_config, backend=backend, enable_streaming=True)
    [_ async for _ in agent.act("Hello")]
    # 200 prompt + 75 completion tokens
    assert agent.stats.context_tokens == 275

View File

@@ -0,0 +1,430 @@
from __future__ import annotations
from collections.abc import Callable
from typing import cast
from unittest.mock import AsyncMock
import pytest
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
from vibe.core.llm.types import BackendLike
from vibe.core.middleware import (
ConversationContext,
MiddlewareAction,
MiddlewarePipeline,
MiddlewareResult,
ResetReason,
)
from vibe.core.tools.base import BaseToolConfig, ToolPermission
from vibe.core.tools.builtins.todo import TodoArgs
from vibe.core.types import (
AssistantEvent,
FunctionCall,
LLMChunk,
LLMMessage,
Role,
ToolCall,
ToolCallEvent,
ToolResultEvent,
)
from vibe.core.utils import (
ApprovalResponse,
CancellationReason,
get_user_cancellation_message,
)
class InjectBeforeMiddleware:
    """Middleware that injects a fixed message before every turn."""
    # Text appended to the user's message; tests read it via the class.
    injectedMessage = "<injected>"
    async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
        """Inject a message just before the current step executes."""
        return MiddlewareResult(
            action=MiddlewareAction.INJECT_MESSAGE, message=self.injectedMessage
        )
    async def after_turn(self, context: ConversationContext) -> MiddlewareResult:
        # No post-turn behavior.
        return MiddlewareResult()
    def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
        # Stateless; nothing to reset.
        return None
def make_config(
    *,
    disable_logging: bool = True,
    enabled_tools: list[str] | None = None,
    tools: dict[str, BaseToolConfig] | None = None,
) -> VibeConfig:
    """Build a minimal VibeConfig for agent tests.

    Logging is off by default, auto-compaction is disabled, and prompt extras
    (project context, prompt detail, model info) are stripped so observer
    assertions see a stable system prompt.
    """
    return VibeConfig(
        session_logging=SessionLoggingConfig(enabled=not disable_logging),
        auto_compact_threshold=0,
        system_prompt_id="tests",
        include_project_context=False,
        include_prompt_detail=False,
        include_model_info=False,
        enabled_tools=enabled_tools or [],
        tools=tools or {},
    )
@pytest.fixture
def observer_capture() -> tuple[
    list[tuple[Role, str | None]], Callable[[LLMMessage], None]
]:
    """Return (captured, observer): the observer appends (role, content) pairs."""
    captured: list[tuple[Role, str | None]] = []
    def _record(msg: LLMMessage) -> None:
        captured.append((msg.role, msg.content))
    return captured, _record
@pytest.mark.asyncio
async def test_act_flushes_batched_messages_with_injection_middleware(
    observer_capture,
) -> None:
    """An INJECT_MESSAGE middleware's text is appended to the outgoing user message."""
    observed, observer = observer_capture
    backend = FakeBackend([mock_llm_chunk(content="I can write very efficient code.")])
    agent = Agent(make_config(), message_observer=observer, backend=backend)
    agent.middleware_pipeline.add(InjectBeforeMiddleware())
    async for _ in agent.act("How can you help?"):
        pass
    assert len(observed) == 3
    assert [r for r, _ in observed] == [Role.system, Role.user, Role.assistant]
    assert observed[0][1] == "You are Vibe, a super useful programming assistant."
    # injected content should be appended to the user's message before emission
    assert (
        observed[1][1]
        == f"How can you help?\n\n{InjectBeforeMiddleware.injectedMessage}"
    )
    assert observed[2][1] == "I can write very efficient code."
@pytest.mark.asyncio
async def test_stop_action_flushes_user_msg_before_returning(observer_capture) -> None:
    """Even when the agent stops immediately, the system and user messages are
    still flushed to the observer before act() returns."""
    observed, observer = observer_capture
    # max_turns=0 forces an immediate STOP on the first before_turn
    backend = FakeBackend([
        mock_llm_chunk(content="My response will never reach you...")
    ])
    agent = Agent(
        make_config(), message_observer=observer, max_turns=0, backend=backend
    )
    async for _ in agent.act("Greet."):
        pass
    assert len(observed) == 2
    # user's message should have been flushed before returning
    assert [r for r, _ in observed] == [Role.system, Role.user]
    assert observed[0][1] == "You are Vibe, a super useful programming assistant."
    assert observed[1][1] == "Greet."
@pytest.mark.asyncio
async def test_act_emits_user_and_assistant_msgs(observer_capture) -> None:
    """A simple turn emits system, user and assistant messages, in that order."""
    observed, observer = observer_capture
    backend = FakeBackend([mock_llm_chunk(content="Pong!")])
    agent = Agent(make_config(), message_observer=observer, backend=backend)
    async for _ in agent.act("Ping?"):
        pass
    assert len(observed) == 3
    assert [r for r, _ in observed] == [Role.system, Role.user, Role.assistant]
    assert observed[1][1] == "Ping?"
    assert observed[2][1] == "Pong!"
@pytest.mark.asyncio
async def test_act_yields_assistant_event_with_usage_stats() -> None:
    """The yielded AssistantEvent carries the backend's usage and session total."""
    backend = FakeBackend([mock_llm_chunk(content="Pong!")])
    agent = Agent(make_config(), backend=backend)
    events = [ev async for ev in agent.act("Ping?")]
    assert len(events) == 1
    ev = events[-1]
    assert isinstance(ev, AssistantEvent)
    assert ev.content == "Pong!"
    # stats come from tests.mock.utils.mock_llm_chunk (prompt=10, completion=5)
    assert ev.prompt_tokens == 10
    assert ev.completion_tokens == 5
    assert ev.session_total_tokens == 15
@pytest.mark.asyncio
async def test_act_streams_batched_chunks_in_order() -> None:
    """Streamed chunks are batched into ordered AssistantEvents and the stored
    assistant message holds the full concatenated content."""
    backend = FakeBackend([
        mock_llm_chunk(content="Hello"),
        mock_llm_chunk(content=" from"),
        mock_llm_chunk(content=" Vibe"),
        mock_llm_chunk(content="! "),
        mock_llm_chunk(content="More"),
        mock_llm_chunk(content=" and"),
        mock_llm_chunk(content=" end"),
    ])
    agent = Agent(make_config(), backend=backend, enable_streaming=True)
    events = [event async for event in agent.act("Stream, please.")]
    assert len(events) == 2
    assert [event.content for event in events if isinstance(event, AssistantEvent)] == [
        "Hello from Vibe! More",
        " and end",
    ]
    assert agent.messages[-1].role == Role.assistant
    assert agent.messages[-1].content == "Hello from Vibe! More and end"
@pytest.mark.asyncio
async def test_act_handles_streaming_with_tool_call_events_in_sequence() -> None:
    """A streamed tool call yields Assistant/ToolCall/ToolResult/Assistant
    events in that order, ending with the follow-up assistant message."""
    todo_tool_call = ToolCall(
        id="tc_stream",
        index=0,
        function=FunctionCall(name="todo", arguments='{"action": "read"}'),
    )
    backend = FakeBackend([
        mock_llm_chunk(content="Checking your todos."),
        mock_llm_chunk(content="", tool_calls=[todo_tool_call]),
        mock_llm_chunk(content="", finish_reason="stop"),
        # Consumed by the follow-up turn after the tool result is fed back.
        mock_llm_chunk(content="Done reviewing todos."),
    ])
    agent = Agent(
        make_config(
            enabled_tools=["todo"],
            tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
        ),
        backend=backend,
        auto_approve=True,
        enable_streaming=True,
    )
    events = [event async for event in agent.act("What about my todos?")]
    assert [type(event) for event in events] == [
        AssistantEvent,
        ToolCallEvent,
        ToolResultEvent,
        AssistantEvent,
    ]
    assert isinstance(events[0], AssistantEvent)
    assert events[0].content == "Checking your todos."
    assert isinstance(events[1], ToolCallEvent)
    assert events[1].tool_name == "todo"
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].error is None
    assert events[2].skipped is False
    assert isinstance(events[3], AssistantEvent)
    assert events[3].content == "Done reviewing todos."
    assert agent.messages[-1].content == "Done reviewing todos."
@pytest.mark.asyncio
async def test_act_handles_tool_call_chunk_with_content() -> None:
    """Text arriving in the same chunk as a tool call is merged into the
    assistant message rather than lost."""
    todo_tool_call = ToolCall(
        id="tc_content",
        index=0,
        function=FunctionCall(name="todo", arguments='{"action": "read"}'),
    )
    backend = FakeBackend([
        mock_llm_chunk(content="Preparing "),
        mock_llm_chunk(content="todo request", tool_calls=[todo_tool_call]),
        mock_llm_chunk(content=" complete", finish_reason="stop"),
    ])
    agent = Agent(
        make_config(
            enabled_tools=["todo"],
            tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
        ),
        backend=backend,
        auto_approve=True,
        enable_streaming=True,
    )
    events = [event async for event in agent.act("Check todos with content.")]
    assert [type(event) for event in events] == [
        AssistantEvent,
        AssistantEvent,
        ToolCallEvent,
        ToolResultEvent,
    ]
    assert isinstance(events[0], AssistantEvent)
    assert events[0].content == "Preparing todo request"
    assert isinstance(events[1], AssistantEvent)
    assert events[1].content == " complete"
    # The stored assistant message contains the fully concatenated text.
    assert any(
        m.role == Role.assistant and m.content == "Preparing todo request complete"
        for m in agent.messages
    )
@pytest.mark.asyncio
async def test_act_merges_streamed_tool_call_arguments() -> None:
    """Argument fragments sharing a tool-call id/index are concatenated into
    one complete call before execution."""
    tool_call_part_one = ToolCall(
        id="tc_merge",
        index=0,
        function=FunctionCall(
            name="todo", arguments='{"action": "read", "note": "First '
        ),
    )
    tool_call_part_two = ToolCall(
        id="tc_merge", index=0, function=FunctionCall(name="todo", arguments='part"}')
    )
    backend = FakeBackend([
        mock_llm_chunk(content="Planning: "),
        mock_llm_chunk(content="", tool_calls=[tool_call_part_one]),
        mock_llm_chunk(content="", tool_calls=[tool_call_part_two]),
    ])
    agent = Agent(
        make_config(
            enabled_tools=["todo"],
            tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
        ),
        backend=backend,
        auto_approve=True,
        enable_streaming=True,
    )
    events = [event async for event in agent.act("Merge streamed tool call args.")]
    assert [type(event) for event in events] == [
        AssistantEvent,
        ToolCallEvent,
        ToolResultEvent,
    ]
    call_event = events[1]
    assert isinstance(call_event, ToolCallEvent)
    assert call_event.tool_call_id == "tc_merge"
    call_args = cast(TodoArgs, call_event.args)
    assert call_args.action == "read"
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].error is None
    assert events[2].skipped is False
    # The stored assistant message must hold the fully merged arguments.
    assistant_with_calls = next(
        m for m in agent.messages if m.role == Role.assistant and m.tool_calls
    )
    reconstructed_calls = assistant_with_calls.tool_calls or []
    assert len(reconstructed_calls) == 1
    assert reconstructed_calls[0].function.arguments == (
        '{"action": "read", "note": "First part"}'
    )
@pytest.mark.asyncio
async def test_act_raises_when_stream_never_signals_finish() -> None:
    """If the stream ends without a finish_reason, the agent raises RuntimeError."""
    class IncompleteStreamingBackend(BackendLike):
        # Backend whose stream yields chunks but never a terminal one.
        def __init__(self, chunks: list[LLMChunk]) -> None:
            self._chunks = list(chunks)
        async def __aenter__(self):
            return self
        async def __aexit__(self, exc_type, exc, tb):
            return None
        async def complete_streaming(self, **_: object):
            while self._chunks:
                yield self._chunks.pop(0)
        async def complete(self, **_: object):
            return mock_llm_chunk(content="", finish_reason="stop")
        async def count_tokens(self, **_: object) -> int:
            return 0
    backend = IncompleteStreamingBackend([mock_llm_chunk(content="partial")])
    agent = Agent(make_config(), backend=backend, enable_streaming=True)
    with pytest.raises(RuntimeError, match="Streamed completion returned no chunks"):
        [event async for event in agent.act("Will this finish?")]
@pytest.mark.asyncio
async def test_act_handles_user_cancellation_during_streaming() -> None:
    """Denying a tool approval mid-stream skips the tool, ends the turn, and
    never runs after_turn middleware."""
    class CountingMiddleware(MiddlewarePipeline):
        # NOTE(review): subclasses MiddlewarePipeline but is *added to* one —
        # only the before/after counters matter for this test.
        def __init__(self) -> None:
            self.before_calls = 0
            self.after_calls = 0
        async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
            self.before_calls += 1
            return MiddlewareResult()
        async def after_turn(self, context: ConversationContext) -> MiddlewareResult:
            self.after_calls += 1
            return MiddlewareResult()
        def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
            return None
    todo_tool_call = ToolCall(
        id="tc_cancel",
        index=0,
        function=FunctionCall(name="todo", arguments='{"action": "read"}'),
    )
    backend = FakeBackend([
        mock_llm_chunk(content="Preparing "),
        mock_llm_chunk(content="todo request", tool_calls=[todo_tool_call]),
        mock_llm_chunk(content="", finish_reason="stop"),
    ])
    agent = Agent(
        make_config(
            enabled_tools=["todo"],
            tools={"todo": BaseToolConfig(permission=ToolPermission.ASK)},
        ),
        backend=backend,
        auto_approve=False,
        enable_streaming=True,
    )
    middleware = CountingMiddleware()
    agent.middleware_pipeline.add(middleware)
    # Approval callback denies the call with a user-cancellation message.
    agent.set_approval_callback(
        lambda _name, _args, _id: (
            ApprovalResponse.NO,
            str(get_user_cancellation_message(CancellationReason.OPERATION_CANCELLED)),
        )
    )
    agent.interaction_logger.save_interaction = AsyncMock(return_value=None)
    events = [event async for event in agent.act("Cancel mid stream?")]
    assert [type(event) for event in events] == [
        AssistantEvent,
        ToolCallEvent,
        ToolResultEvent,
    ]
    assert middleware.before_calls == 1
    assert middleware.after_calls == 0
    assert isinstance(events[-1], ToolResultEvent)
    assert events[-1].skipped is True
    assert events[-1].skip_reason is not None
    assert "<user_cancellation>" in events[-1].skip_reason
    assert agent.interaction_logger.save_interaction.await_count == 2
@pytest.mark.asyncio
async def test_act_flushes_and_logs_when_streaming_errors(observer_capture) -> None:
    """A backend failure mid-stream propagates, yet the interaction is still logged."""
    observed, observer = observer_capture
    backend = FakeBackend(exception_to_raise=RuntimeError("boom in streaming"))
    agent = Agent(
        make_config(), backend=backend, message_observer=observer, enable_streaming=True
    )
    agent.interaction_logger.save_interaction = AsyncMock(return_value=None)
    with pytest.raises(RuntimeError, match="boom in streaming"):
        async for _ in agent.act("Trigger stream failure"):
            pass
    # Only the pre-turn messages were observed before the failure surfaced.
    observed_roles = [role for role, _ in observed]
    assert observed_roles == [Role.system, Role.user]
    assert agent.interaction_logger.save_interaction.await_count == 1

711
tests/test_agent_stats.py Normal file
View File

@@ -0,0 +1,711 @@
from __future__ import annotations
from collections.abc import Callable
import pytest
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from vibe.core.agent import Agent
from vibe.core.config import (
Backend,
ModelConfig,
ProviderConfig,
SessionLoggingConfig,
VibeConfig,
)
from vibe.core.tools.base import BaseToolConfig, ToolPermission
from vibe.core.types import (
AgentStats,
AssistantEvent,
CompactEndEvent,
CompactStartEvent,
FunctionCall,
LLMMessage,
Role,
ToolCall,
)
def make_config(
    *,
    system_prompt_id: str = "tests",
    active_model: str = "devstral-latest",
    input_price: float = 0.4,
    output_price: float = 2.0,
    disable_logging: bool = True,
    auto_compact_threshold: int = 0,
    include_project_context: bool = False,
    include_prompt_detail: bool = False,
    enabled_tools: list[str] | None = None,
    todo_permission: ToolPermission = ToolPermission.ALWAYS,
) -> VibeConfig:
    """Build a test VibeConfig with two mistral-provider models and one lechat model.

    `input_price`/`output_price` apply only to the "devstral-latest" alias; the
    other models keep fixed pricing so pricing-change tests have a known target.
    """
    models = [
        ModelConfig(
            name="mistral-vibe-cli-latest",
            provider="mistral",
            alias="devstral-latest",
            input_price=input_price,
            output_price=output_price,
        ),
        ModelConfig(
            name="devstral-small-latest",
            provider="mistral",
            alias="devstral-small",
            input_price=0.1,
            output_price=0.3,
        ),
        ModelConfig(
            name="strawberry",
            provider="lechat",
            alias="strawberry",
            input_price=2.5,
            output_price=10.0,
        ),
    ]
    providers = [
        ProviderConfig(
            name="mistral",
            api_base="https://api.mistral.ai/v1",
            api_key_env_var="MISTRAL_API_KEY",
            backend=Backend.MISTRAL,
        ),
        ProviderConfig(
            name="lechat",
            api_base="https://api.mistral.ai/v1",
            api_key_env_var="LECHAT_API_KEY",
            backend=Backend.MISTRAL,
        ),
    ]
    return VibeConfig(
        session_logging=SessionLoggingConfig(enabled=not disable_logging),
        auto_compact_threshold=auto_compact_threshold,
        system_prompt_id=system_prompt_id,
        include_project_context=include_project_context,
        include_prompt_detail=include_prompt_detail,
        active_model=active_model,
        models=models,
        providers=providers,
        enabled_tools=enabled_tools or [],
        tools={"todo": BaseToolConfig(permission=todo_permission)},
    )
@pytest.fixture
def observer_capture() -> tuple[list[LLMMessage], Callable[[LLMMessage], None]]:
    """Yield a capture list together with an observer that appends to it."""
    captured: list[LLMMessage] = []
    return captured, captured.append
class TestAgentStatsHelpers:
    """Unit tests for AgentStats bookkeeping helpers."""
    def test_update_pricing(self) -> None:
        """update_pricing stores the given per-million token prices."""
        stats = AgentStats()
        stats.update_pricing(1.5, 3.0)
        assert stats.input_price_per_million == 1.5
        assert stats.output_price_per_million == 3.0
    def test_reset_context_state_preserves_cumulative(self) -> None:
        """reset_context_state zeroes per-context fields but keeps session totals."""
        stats = AgentStats(
            steps=5,
            session_prompt_tokens=1000,
            session_completion_tokens=500,
            tool_calls_succeeded=3,
            tool_calls_failed=1,
            context_tokens=800,
            last_turn_prompt_tokens=100,
            last_turn_completion_tokens=50,
            last_turn_duration=1.5,
            tokens_per_second=33.3,
            input_price_per_million=0.4,
            output_price_per_million=2.0,
        )
        stats.reset_context_state()
        # Cumulative session counters and pricing survive the reset.
        assert stats.steps == 5
        assert stats.session_prompt_tokens == 1000
        assert stats.session_completion_tokens == 500
        assert stats.tool_calls_succeeded == 3
        assert stats.tool_calls_failed == 1
        assert stats.input_price_per_million == 0.4
        assert stats.output_price_per_million == 2.0
        # Per-context and per-turn fields are zeroed.
        assert stats.context_tokens == 0
        assert stats.last_turn_prompt_tokens == 0
        assert stats.last_turn_completion_tokens == 0
        assert stats.last_turn_duration == 0.0
        assert stats.tokens_per_second == 0.0
    def test_session_cost_computed_from_current_pricing(self) -> None:
        """session_cost derives from token totals and the *current* pricing."""
        stats = AgentStats(
            session_prompt_tokens=1_000_000,
            session_completion_tokens=500_000,
            input_price_per_million=1.0,
            output_price_per_million=2.0,
        )
        # Cost = 1M * $1/M + 0.5M * $2/M = $1 + $1 = $2
        assert stats.session_cost == 2.0
        stats.update_pricing(2.0, 4.0)
        # Cost = 1M * $2/M + 0.5M * $4/M = $2 + $2 = $4
        assert stats.session_cost == 4.0
class TestReloadPreservesStats:
    """reload_with_initial_messages must keep cumulative stats across reloads."""
    @pytest.mark.asyncio
    async def test_reload_preserves_session_tokens(self) -> None:
        """Session token totals survive a reload."""
        backend = FakeBackend([
            mock_llm_chunk(content="First response", finish_reason="stop")
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Hello"):
            pass
        old_session_prompt = agent.stats.session_prompt_tokens
        old_session_completion = agent.stats.session_completion_tokens
        assert old_session_prompt > 0
        assert old_session_completion > 0
        await agent.reload_with_initial_messages()
        assert agent.stats.session_prompt_tokens == old_session_prompt
        assert agent.stats.session_completion_tokens == old_session_completion
    @pytest.mark.asyncio
    async def test_reload_preserves_tool_call_stats(self) -> None:
        """Tool-call success/agreement counters survive a reload."""
        backend = FakeBackend([
            mock_llm_chunk(
                content="Calling tool",
                tool_calls=[
                    ToolCall(
                        id="tc1",
                        function=FunctionCall(
                            name="todo", arguments='{"action": "read"}'
                        ),
                    )
                ],
            ),
            mock_llm_chunk(content="Done", finish_reason="stop"),
        ])
        config = make_config(enabled_tools=["todo"])
        agent = Agent(config, auto_approve=True, backend=backend)
        async for _ in agent.act("Check todos"):
            pass
        assert agent.stats.tool_calls_succeeded == 1
        assert agent.stats.tool_calls_agreed == 1
        await agent.reload_with_initial_messages()
        assert agent.stats.tool_calls_succeeded == 1
        assert agent.stats.tool_calls_agreed == 1
    @pytest.mark.asyncio
    async def test_reload_preserves_steps(self) -> None:
        """The cumulative step count survives a reload."""
        backend = FakeBackend([
            mock_llm_chunk(content="R1", finish_reason="stop"),
            mock_llm_chunk(content="R2", finish_reason="stop"),
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("First"):
            pass
        async for _ in agent.act("Second"):
            pass
        old_steps = agent.stats.steps
        assert old_steps >= 2
        await agent.reload_with_initial_messages()
        assert agent.stats.steps == old_steps
    @pytest.mark.asyncio
    async def test_reload_preserves_context_tokens_when_messages_preserved(
        self,
    ) -> None:
        """Context token count is kept when the conversation is preserved."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(), backend=backend)
        [_ async for _ in agent.act("Hello")]
        assert agent.stats.context_tokens > 0
        initial_context_tokens = agent.stats.context_tokens
        assert len(agent.messages) > 1
        await agent.reload_with_initial_messages()
        assert len(agent.messages) > 1
        assert agent.stats.context_tokens == initial_context_tokens
    @pytest.mark.asyncio
    async def test_reload_resets_context_tokens_when_no_messages(self) -> None:
        """With only the system message present, context tokens stay at zero."""
        backend = FakeBackend([])
        agent = Agent(make_config(), backend=backend)
        assert len(agent.messages) == 1
        assert agent.stats.context_tokens == 0
        await agent.reload_with_initial_messages()
        assert len(agent.messages) == 1
        assert agent.stats.context_tokens == 0
    @pytest.mark.asyncio
    async def test_reload_resets_context_tokens_when_system_prompt_changes(
        self,
    ) -> None:
        """Changing the system prompt invalidates the cached context token count."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        config1 = make_config(system_prompt_id="tests")
        config2 = make_config(system_prompt_id="cli")
        agent = Agent(config1, backend=backend)
        [_ async for _ in agent.act("Hello")]
        assert agent.stats.context_tokens > 0
        assert len(agent.messages) > 1
        await agent.reload_with_initial_messages(config=config2)
        assert len(agent.messages) > 1
        assert agent.stats.context_tokens == 0
    @pytest.mark.asyncio
    async def test_reload_updates_pricing_from_new_model(self, monkeypatch) -> None:
        """Reloading with a different active model refreshes pricing."""
        monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        config_mistral = make_config(active_model="devstral-latest")
        agent = Agent(config_mistral, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        assert agent.stats.input_price_per_million == 0.4
        assert agent.stats.output_price_per_million == 2.0
        config_other = make_config(active_model="strawberry")
        await agent.reload_with_initial_messages(config=config_other)
        # "strawberry" pricing as declared in make_config.
        assert agent.stats.input_price_per_million == 2.5
        assert agent.stats.output_price_per_million == 10.0
    @pytest.mark.asyncio
    async def test_reload_accumulates_tokens_across_configs(self, monkeypatch) -> None:
        """Token totals keep accumulating after a reload with a new config."""
        monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
        backend = FakeBackend([
            mock_llm_chunk(content="First", finish_reason="stop"),
            mock_llm_chunk(content="After reload", finish_reason="stop"),
        ])
        config1 = make_config(active_model="devstral-latest")
        agent = Agent(config1, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        tokens_after_first = (
            agent.stats.session_prompt_tokens + agent.stats.session_completion_tokens
        )
        config2 = make_config(active_model="strawberry")
        await agent.reload_with_initial_messages(config=config2)
        async for _ in agent.act("Continue"):
            pass
        tokens_after_second = (
            agent.stats.session_prompt_tokens + agent.stats.session_completion_tokens
        )
        assert tokens_after_second > tokens_after_first
class TestReloadPreservesMessages:
    """reload_with_initial_messages must preserve the conversation history."""
    @pytest.mark.asyncio
    async def test_reload_preserves_conversation_messages(self) -> None:
        """User and assistant messages survive a reload unchanged."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Hello"):
            pass
        assert len(agent.messages) == 3
        old_user_content = agent.messages[1].content
        old_assistant_content = agent.messages[2].content
        await agent.reload_with_initial_messages()
        assert len(agent.messages) == 3
        assert agent.messages[0].role == Role.system
        assert agent.messages[1].role == Role.user
        assert agent.messages[1].content == old_user_content
        assert agent.messages[2].role == Role.assistant
        assert agent.messages[2].content == old_assistant_content
    @pytest.mark.asyncio
    async def test_reload_updates_system_prompt_preserves_rest(self) -> None:
        """A new system prompt replaces message 0 but keeps the rest intact."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        config1 = make_config(system_prompt_id="tests")
        agent = Agent(config1, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        old_system = agent.messages[0].content
        old_user = agent.messages[1].content
        config2 = make_config(system_prompt_id="cli")
        await agent.reload_with_initial_messages(config=config2)
        assert agent.messages[0].content != old_system
        assert agent.messages[1].content == old_user
    @pytest.mark.asyncio
    async def test_reload_with_no_messages_stays_empty(self) -> None:
        """Reloading a fresh agent leaves only the system message."""
        backend = FakeBackend([])
        agent = Agent(make_config(), backend=backend)
        assert len(agent.messages) == 1
        await agent.reload_with_initial_messages()
        assert len(agent.messages) == 1
        assert agent.messages[0].role == Role.system
    @pytest.mark.asyncio
    async def test_reload_notifies_observer_with_all_messages(
        self, observer_capture
    ) -> None:
        """After a reload, the observer is replayed the full message history."""
        observed, observer = observer_capture
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(), message_observer=observer, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        observed.clear()
        await agent.reload_with_initial_messages()
        assert len(observed) == 3
        assert observed[0].role == Role.system
        assert observed[1].role == Role.user
        assert observed[2].role == Role.assistant
class TestCompactStatsHandling:
    """compact() must shrink the context while preserving cumulative stats."""
    @pytest.mark.asyncio
    async def test_compact_preserves_cumulative_stats(self) -> None:
        """Compacting adds to (never resets) cumulative session counters."""
        backend = FakeBackend([
            mock_llm_chunk(content="First response", finish_reason="stop"),
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Build something"):
            pass
        tokens_before_compact = agent.stats.session_prompt_tokens
        completions_before = agent.stats.session_completion_tokens
        steps_before = agent.stats.steps
        await agent.compact()
        # Cumulative stats include the compact turn
        assert agent.stats.session_prompt_tokens > tokens_before_compact
        assert agent.stats.session_completion_tokens > completions_before
        assert agent.stats.steps > steps_before
    @pytest.mark.asyncio
    async def test_compact_updates_context_tokens(self) -> None:
        """Compacting reduces the live context token count."""
        backend = FakeBackend([
            mock_llm_chunk(content="Long response " * 100, finish_reason="stop"),
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Do something complex"):
            pass
        context_before = agent.stats.context_tokens
        await agent.compact()
        assert agent.stats.context_tokens < context_before
    @pytest.mark.asyncio
    async def test_compact_preserves_tool_call_stats(self) -> None:
        """Tool-call counters are untouched by compaction."""
        backend = FakeBackend([
            mock_llm_chunk(
                content="Using tool",
                tool_calls=[
                    ToolCall(
                        id="tc1",
                        function=FunctionCall(
                            name="todo", arguments='{"action": "read"}'
                        ),
                    )
                ],
            ),
            mock_llm_chunk(content="Done", finish_reason="stop"),
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
        ])
        config = make_config(enabled_tools=["todo"])
        agent = Agent(config, auto_approve=True, backend=backend)
        async for _ in agent.act("Check todos"):
            pass
        assert agent.stats.tool_calls_succeeded == 1
        await agent.compact()
        assert agent.stats.tool_calls_succeeded == 1
    @pytest.mark.asyncio
    async def test_compact_resets_session_id(self) -> None:
        """Compacting starts a new session id, kept in sync with the logger."""
        backend = FakeBackend([
            mock_llm_chunk(content="Long response " * 100, finish_reason="stop"),
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
        ])
        agent = Agent(make_config(disable_logging=False), backend=backend)
        original_session_id = agent.session_id
        original_logger_session_id = agent.interaction_logger.session_id
        assert agent.session_id == original_logger_session_id
        async for _ in agent.act("Do something complex"):
            pass
        await agent.compact()
        assert agent.session_id != original_session_id
        assert agent.session_id == agent.interaction_logger.session_id
class TestAutoCompactIntegration:
    """Auto-compaction fires inside act() when context exceeds the threshold."""
    @pytest.mark.asyncio
    async def test_auto_compact_triggers_and_preserves_stats(self) -> None:
        """act() emits CompactStart/CompactEnd before the final assistant event."""
        observed: list[tuple[Role, str | None]] = []
        def observer(msg: LLMMessage) -> None:
            observed.append((msg.role, msg.content))
        backend = FakeBackend([
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
            mock_llm_chunk(content="<final>", finish_reason="stop"),
        ])
        cfg = VibeConfig(
            session_logging=SessionLoggingConfig(enabled=False),
            auto_compact_threshold=1,
        )
        agent = Agent(cfg, message_observer=observer, backend=backend)
        # Force the context over the threshold so act() compacts first.
        agent.stats.context_tokens = 2
        events = [ev async for ev in agent.act("Hello")]
        assert len(events) == 3
        assert isinstance(events[0], CompactStartEvent)
        assert isinstance(events[1], CompactEndEvent)
        assert isinstance(events[2], AssistantEvent)
        start: CompactStartEvent = events[0]
        end: CompactEndEvent = events[1]
        final: AssistantEvent = events[2]
        assert start.current_context_tokens == 2
        assert start.threshold == 1
        assert end.old_context_tokens == 2
        assert end.new_context_tokens >= 1
        assert final.content == "<final>"
        roles = [r for r, _ in observed]
        assert roles == [Role.system, Role.user, Role.assistant]
        # The compacted user message embeds the original request text.
        assert (
            observed[1][1] is not None
            and "Last request from user was: Hello" in observed[1][1]
        )
class TestClearHistoryFullReset:
    """clear_history() performs a full reset of stats and messages."""
    @pytest.mark.asyncio
    async def test_clear_history_fully_resets_stats(self) -> None:
        """All cumulative counters return to zero after clearing."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Hello"):
            pass
        assert agent.stats.session_prompt_tokens > 0
        assert agent.stats.steps > 0
        await agent.clear_history()
        assert agent.stats.session_prompt_tokens == 0
        assert agent.stats.session_completion_tokens == 0
        assert agent.stats.steps == 0
    @pytest.mark.asyncio
    async def test_clear_history_preserves_pricing(self) -> None:
        """Pricing is model-derived and survives a history clear."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        config = make_config(input_price=0.4, output_price=2.0)
        agent = Agent(config, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        await agent.clear_history()
        assert agent.stats.input_price_per_million == 0.4
        assert agent.stats.output_price_per_million == 2.0
    @pytest.mark.asyncio
    async def test_clear_history_removes_messages(self) -> None:
        """Only the system message remains after clearing."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Hello"):
            pass
        assert len(agent.messages) == 3
        await agent.clear_history()
        assert len(agent.messages) == 1
        assert agent.messages[0].role == Role.system
    @pytest.mark.asyncio
    async def test_clear_history_resets_session_id(self) -> None:
        """Clearing starts a new session id, kept in sync with the logger."""
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        agent = Agent(make_config(disable_logging=False), backend=backend)
        original_session_id = agent.session_id
        original_logger_session_id = agent.interaction_logger.session_id
        assert agent.session_id == original_logger_session_id
        async for _ in agent.act("Hello"):
            pass
        await agent.clear_history()
        assert agent.session_id != original_session_id
        assert agent.session_id == agent.interaction_logger.session_id
class TestStatsEdgeCases:
    """Edge cases around pricing changes, repeated reloads, and compaction."""
    @pytest.mark.asyncio
    async def test_session_cost_approximation_on_model_change(
        self, monkeypatch
    ) -> None:
        """Switching to a pricier model raises the computed session cost."""
        monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
        backend = FakeBackend([
            mock_llm_chunk(content="Response", finish_reason="stop")
        ])
        config1 = make_config(active_model="devstral-latest")
        agent = Agent(config1, backend=backend)
        async for _ in agent.act("Hello"):
            pass
        cost_before = agent.stats.session_cost
        config2 = make_config(active_model="strawberry")
        await agent.reload_with_initial_messages(config=config2)
        cost_after = agent.stats.session_cost
        assert cost_after > cost_before
    @pytest.mark.asyncio
    async def test_multiple_reloads_accumulate_correctly(self) -> None:
        """Token totals are strictly increasing across repeated reloads."""
        backend = FakeBackend([
            mock_llm_chunk(content="R1", finish_reason="stop"),
            mock_llm_chunk(content="R2", finish_reason="stop"),
            mock_llm_chunk(content="R3", finish_reason="stop"),
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("First"):
            pass
        tokens1 = agent.stats.session_total_llm_tokens
        await agent.reload_with_initial_messages()
        async for _ in agent.act("Second"):
            pass
        tokens2 = agent.stats.session_total_llm_tokens
        await agent.reload_with_initial_messages()
        async for _ in agent.act("Third"):
            pass
        tokens3 = agent.stats.session_total_llm_tokens
        assert tokens1 < tokens2 < tokens3
    @pytest.mark.asyncio
    async def test_compact_then_reload_preserves_both(self) -> None:
        """Stats survive compact-then-reload and keep growing afterwards."""
        backend = FakeBackend([
            mock_llm_chunk(content="Initial response", finish_reason="stop"),
            mock_llm_chunk(content="<summary>", finish_reason="stop"),
            mock_llm_chunk(content="After reload", finish_reason="stop"),
        ])
        agent = Agent(make_config(), backend=backend)
        async for _ in agent.act("Build something"):
            pass
        await agent.compact()
        tokens_after_compact = agent.stats.session_prompt_tokens
        await agent.reload_with_initial_messages()
        assert agent.stats.session_prompt_tokens == tokens_after_compact
        async for _ in agent.act("Continue"):
            pass
        assert agent.stats.session_prompt_tokens > tokens_after_compact
    @pytest.mark.asyncio
    async def test_reload_without_config_preserves_current(self) -> None:
        """reload(config=None) keeps the existing configuration."""
        backend = FakeBackend([])
        original_config = make_config(active_model="devstral-latest")
        agent = Agent(original_config, backend=backend)
        await agent.reload_with_initial_messages(config=None)
        assert agent.config.active_model == "devstral-latest"
    @pytest.mark.asyncio
    async def test_reload_with_new_config_updates_it(self) -> None:
        """reload(config=new) swaps in the new configuration."""
        backend = FakeBackend([])
        original_config = make_config(active_model="devstral-latest")
        agent = Agent(original_config, backend=backend)
        new_config = make_config(active_model="devstral-small")
        await agent.reload_with_initial_messages(config=new_config)
        assert agent.config.active_model == "devstral-small"

View File

@@ -0,0 +1,477 @@
from __future__ import annotations
import asyncio
import json
from typing import Any
import pytest
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from tests.stubs.fake_tool import FakeTool
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
from vibe.core.tools.base import BaseToolConfig, ToolPermission
from vibe.core.tools.builtins.todo import TodoItem
from vibe.core.types import (
AssistantEvent,
BaseEvent,
FunctionCall,
LLMMessage,
Role,
SyncApprovalCallback,
ToolCall,
ToolCallEvent,
ToolResultEvent,
)
from vibe.core.utils import ApprovalResponse
async def act_and_collect_events(agent: Agent, prompt: str) -> list[BaseEvent]:
    """Run one agent turn for `prompt` and return every emitted event in order."""
    collected: list[BaseEvent] = []
    async for event in agent.act(prompt):
        collected.append(event)
    return collected
def make_config(todo_permission: ToolPermission = ToolPermission.ALWAYS) -> VibeConfig:
    """Build a minimal VibeConfig with only the todo tool enabled."""
    return VibeConfig(
        system_prompt_id="tests",
        session_logging=SessionLoggingConfig(enabled=False),
        auto_compact_threshold=0,
        include_project_context=False,
        include_prompt_detail=False,
        enabled_tools=["todo"],
        tools={"todo": BaseToolConfig(permission=todo_permission)},
    )
def make_todo_tool_call(call_id: str, action: str = "read") -> ToolCall:
    """Build a todo ToolCall whose JSON arguments carry the given action."""
    arguments = json.dumps({"action": action})
    return ToolCall(id=call_id, function=FunctionCall(name="todo", arguments=arguments))
def make_agent(
    *,
    auto_approve: bool = True,
    todo_permission: ToolPermission = ToolPermission.ALWAYS,
    backend: FakeBackend,
    approval_callback: SyncApprovalCallback | None = None,
) -> Agent:
    """Build an Agent wired to `backend`, optionally installing an approval callback."""
    config = make_config(todo_permission=todo_permission)
    agent = Agent(config, auto_approve=auto_approve, backend=backend)
    if approval_callback is not None:
        agent.set_approval_callback(approval_callback)
    return agent
@pytest.mark.asyncio
async def test_single_tool_call_executes_under_auto_approve() -> None:
    """With auto-approve, a tool call executes and its result feeds the next turn."""
    mocked_tool_call_id = "call_1"
    tool_call = make_todo_tool_call(mocked_tool_call_id)
    backend = FakeBackend([
        mock_llm_chunk(content="Let me check your todos.", tool_calls=[tool_call]),
        mock_llm_chunk(content="I retrieved 0 todos.", finish_reason="stop"),
    ])
    agent = make_agent(auto_approve=True, backend=backend)
    events = await act_and_collect_events(agent, "What's my todo list?")
    assert [type(e) for e in events] == [
        AssistantEvent,
        ToolCallEvent,
        ToolResultEvent,
        AssistantEvent,
    ]
    assert isinstance(events[0], AssistantEvent)
    assert events[0].content == "Let me check your todos."
    assert isinstance(events[1], ToolCallEvent)
    assert events[1].tool_name == "todo"
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].error is None
    assert events[2].skipped is False
    assert events[2].result is not None
    assert isinstance(events[3], AssistantEvent)
    assert events[3].content == "I retrieved 0 todos."
    # check conversation history
    tool_msgs = [m for m in agent.messages if m.role == Role.tool]
    assert len(tool_msgs) == 1
    assert tool_msgs[-1].tool_call_id == mocked_tool_call_id
    assert "total_count" in (tool_msgs[-1].content or "")
@pytest.mark.asyncio
async def test_tool_call_requires_approval_if_not_auto_approved() -> None:
    """Without auto-approve or a callback, an ASK tool call is skipped and counted
    as rejected."""
    agent = make_agent(
        auto_approve=False,
        todo_permission=ToolPermission.ASK,
        backend=FakeBackend([
            mock_llm_chunk(
                content="Let me check your todos.",
                tool_calls=[make_todo_tool_call("call_2")],
            ),
            mock_llm_chunk(
                content="I cannot execute the tool without approval.",
                finish_reason="stop",
            ),
        ]),
    )
    events = await act_and_collect_events(agent, "What's my todo list?")
    assert isinstance(events[1], ToolCallEvent)
    assert events[1].tool_name == "todo"
    assert isinstance(events[2], ToolResultEvent)
    # Skipped, not errored: the tool never ran.
    assert events[2].skipped is True
    assert events[2].error is None
    assert events[2].result is None
    assert events[2].skip_reason is not None
    assert "not permitted" in events[2].skip_reason.lower()
    assert isinstance(events[3], AssistantEvent)
    assert events[3].content == "I cannot execute the tool without approval."
    assert agent.stats.tool_calls_rejected == 1
    assert agent.stats.tool_calls_agreed == 0
    assert agent.stats.tool_calls_succeeded == 0
@pytest.mark.asyncio
async def test_tool_call_approved_by_callback() -> None:
    """An ASK-permission tool executes when the approval callback answers YES."""

    def always_approve(
        _tool_name: str, _args: dict[str, Any], _tool_call_id: str
    ) -> tuple[str, str | None]:
        return (ApprovalResponse.YES, None)

    backend = FakeBackend([
        mock_llm_chunk(
            content="Let me check your todos.",
            tool_calls=[make_todo_tool_call("call_3")],
        ),
        mock_llm_chunk(content="I retrieved 0 todos.", finish_reason="stop"),
    ])
    agent = make_agent(
        auto_approve=False,
        todo_permission=ToolPermission.ASK,
        approval_callback=always_approve,
        backend=backend,
    )
    events = await act_and_collect_events(agent, "What's my todo list?")
    result_event = events[2]
    assert isinstance(result_event, ToolResultEvent)
    assert result_event.skipped is False
    assert result_event.error is None
    assert result_event.result is not None
    assert agent.stats.tool_calls_agreed == 1
    assert agent.stats.tool_calls_rejected == 0
    assert agent.stats.tool_calls_succeeded == 1
@pytest.mark.asyncio
async def test_tool_call_rejected_when_auto_approve_disabled_and_rejected_by_callback() -> (
    None
):
    """A NO answer from the approval callback skips the tool, keeping the
    callback's custom feedback as the skip reason."""
    custom_feedback = "User declined tool execution"
    def approval_callback(
        _tool_name: str, _args: dict[str, Any], _tool_call_id: str
    ) -> tuple[str, str | None]:
        return (ApprovalResponse.NO, custom_feedback)
    agent = make_agent(
        auto_approve=False,
        todo_permission=ToolPermission.ASK,
        approval_callback=approval_callback,
        backend=FakeBackend([
            mock_llm_chunk(
                content="Let me check your todos.",
                tool_calls=[make_todo_tool_call("call_4")],
            ),
            mock_llm_chunk(
                content="Understood, I won't check the todos.", finish_reason="stop"
            ),
        ]),
    )
    events = await act_and_collect_events(agent, "What's my todo list?")
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].skipped is True
    assert events[2].error is None
    assert events[2].result is None
    # The callback's feedback string is surfaced verbatim.
    assert events[2].skip_reason == custom_feedback
    assert agent.stats.tool_calls_rejected == 1
    assert agent.stats.tool_calls_agreed == 0
    assert agent.stats.tool_calls_succeeded == 0
@pytest.mark.asyncio
async def test_tool_call_skipped_when_permission_is_never() -> None:
    """A NEVER-permission tool is skipped, counted as rejected, and the skip
    reason is echoed into the tool message recorded in history."""
    agent = make_agent(
        auto_approve=False,
        todo_permission=ToolPermission.NEVER,
        backend=FakeBackend([
            mock_llm_chunk(
                content="Let me check your todos.",
                tool_calls=[make_todo_tool_call("call_never")],
            ),
            mock_llm_chunk(content="Tool is disabled.", finish_reason="stop"),
        ]),
    )
    events = await act_and_collect_events(agent, "What's my todo list?")
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].skipped is True
    assert events[2].error is None
    assert events[2].result is None
    assert events[2].skip_reason is not None
    assert "permanently disabled" in events[2].skip_reason.lower()
    # Exactly one todo tool message is recorded, carrying the skip reason.
    # (No separate name check: the filter already constrains m.name == "todo".)
    tool_msgs = [m for m in agent.messages if m.role == Role.tool and m.name == "todo"]
    assert len(tool_msgs) == 1
    assert events[2].skip_reason in (tool_msgs[0].content or "")
    assert agent.stats.tool_calls_rejected == 1
    assert agent.stats.tool_calls_agreed == 0
    assert agent.stats.tool_calls_succeeded == 0
@pytest.mark.asyncio
async def test_approval_always_flips_auto_approve_for_subsequent_calls() -> None:
    """An ALWAYS answer enables auto-approve, so later calls skip the callback."""
    callback_invocations = []
    def approval_callback(
        tool_name: str, _args: dict[str, Any], _tool_call_id: str
    ) -> tuple[str, str | None]:
        callback_invocations.append(tool_name)
        return (ApprovalResponse.ALWAYS, None)
    agent = make_agent(
        auto_approve=False,
        todo_permission=ToolPermission.ASK,
        approval_callback=approval_callback,
        backend=FakeBackend([
            mock_llm_chunk(
                content="First check.", tool_calls=[make_todo_tool_call("call_first")]
            ),
            mock_llm_chunk(content="First done.", finish_reason="stop"),
            mock_llm_chunk(
                content="Second check.", tool_calls=[make_todo_tool_call("call_second")]
            ),
            mock_llm_chunk(content="Second done.", finish_reason="stop"),
        ]),
    )
    events1 = await act_and_collect_events(agent, "First request")
    events2 = await act_and_collect_events(agent, "Second request")
    assert agent.auto_approve is True
    # Only the first call consulted the callback.
    assert len(callback_invocations) == 1
    assert callback_invocations[0] == "todo"
    assert isinstance(events1[2], ToolResultEvent)
    assert events1[2].skipped is False
    assert events1[2].result is not None
    assert isinstance(events2[2], ToolResultEvent)
    assert events2[2].skipped is False
    assert events2[2].result is not None
    assert agent.stats.tool_calls_rejected == 0
    assert agent.stats.tool_calls_succeeded == 2
@pytest.mark.asyncio
async def test_tool_call_with_invalid_action() -> None:
    """An unknown todo action surfaces as a tool error, not a skip."""
    bad_call = ToolCall(
        id="call_5",
        function=FunctionCall(name="todo", arguments='{"action": "invalid_action"}'),
    )
    backend = FakeBackend([
        mock_llm_chunk(content="Let me check your todos.", tool_calls=[bad_call]),
        mock_llm_chunk(
            content="I encountered an error with the action.", finish_reason="stop"
        ),
    ])
    agent = make_agent(auto_approve=True, backend=backend)
    events = await act_and_collect_events(agent, "What's my todo list?")
    result_event = events[2]
    assert isinstance(result_event, ToolResultEvent)
    assert result_event.error is not None
    assert result_event.result is None
    assert "tool_error" in result_event.error.lower()
    assert agent.stats.tool_calls_failed == 1
@pytest.mark.asyncio
async def test_tool_call_with_duplicate_todo_ids() -> None:
    """Writing todos with duplicate ids fails with a uniqueness error."""
    duplicate_todos = [
        TodoItem(id="duplicate", content="Task 1"),
        TodoItem(id="duplicate", content="Task 2"),
    ]
    tool_call = ToolCall(
        id="call_6",
        function=FunctionCall(
            name="todo",
            arguments=json.dumps({
                "action": "write",
                "todos": [t.model_dump() for t in duplicate_todos],
            }),
        ),
    )
    agent = make_agent(
        auto_approve=True,
        backend=FakeBackend([
            mock_llm_chunk(content="Let me write todos.", tool_calls=[tool_call]),
            mock_llm_chunk(
                content="I couldn't write todos with duplicate IDs.",
                finish_reason="stop",
            ),
        ]),
    )
    events = await act_and_collect_events(agent, "Add todos")
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].error is not None
    assert events[2].result is None
    assert "unique" in events[2].error.lower()
    assert agent.stats.tool_calls_failed == 1
@pytest.mark.asyncio
async def test_tool_call_with_exceeding_max_todos() -> None:
    """Writing more todos than the limit fails; the error mentions the cap (100)."""
    many_todos = [TodoItem(id=f"todo_{i}", content=f"Task {i}") for i in range(150)]
    tool_call = ToolCall(
        id="call_7",
        function=FunctionCall(
            name="todo",
            arguments=json.dumps({
                "action": "write",
                "todos": [t.model_dump() for t in many_todos],
            }),
        ),
    )
    agent = make_agent(
        auto_approve=True,
        backend=FakeBackend([
            mock_llm_chunk(content="Let me write todos.", tool_calls=[tool_call]),
            mock_llm_chunk(
                content="I couldn't write that many todos.", finish_reason="stop"
            ),
        ]),
    )
    events = await act_and_collect_events(agent, "Add todos")
    assert isinstance(events[2], ToolResultEvent)
    assert events[2].error is not None
    assert events[2].result is None
    assert "100" in events[2].error
    assert agent.stats.tool_calls_failed == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "exception_class",
    [
        pytest.param(KeyboardInterrupt, id="keyboard_interrupt"),
        pytest.param(asyncio.CancelledError, id="asyncio_cancelled"),
    ],
)
async def test_tool_call_can_be_interrupted(
    exception_class: type[BaseException],
) -> None:
    """An interrupt raised inside a tool propagates, after a ToolResultEvent
    marking the execution as interrupted was emitted."""
    tool_call = ToolCall(
        id="call_8", function=FunctionCall(name="stub_tool", arguments="{}")
    )
    config = VibeConfig(
        session_logging=SessionLoggingConfig(enabled=False),
        auto_compact_threshold=0,
        enabled_tools=["stub_tool"],
    )
    agent = Agent(
        config,
        auto_approve=True,
        backend=FakeBackend([
            mock_llm_chunk(content="Let me use the tool.", tool_calls=[tool_call]),
            mock_llm_chunk(content="Tool execution completed.", finish_reason="stop"),
        ]),
    )
    # no dependency injection available => monkey patch
    agent.tool_manager._available["stub_tool"] = FakeTool
    stub_tool_instance = agent.tool_manager.get("stub_tool")
    assert isinstance(stub_tool_instance, FakeTool)
    stub_tool_instance._exception_to_raise = exception_class()
    events: list[BaseEvent] = []
    with pytest.raises(exception_class):
        async for ev in agent.act("Execute tool"):
            events.append(ev)
    tool_result_event = next(
        (e for e in events if isinstance(e, ToolResultEvent)), None
    )
    assert tool_result_event is not None
    assert tool_result_event.error is not None
    assert "execution interrupted by user" in tool_result_event.error.lower()
@pytest.mark.asyncio
async def test_fill_missing_tool_responses_inserts_placeholders() -> None:
    """When an assistant message carries tool calls that never received a tool
    response, the agent must insert cancellation placeholders before acting."""
    agent = Agent(
        make_config(),
        auto_approve=True,
        backend=FakeBackend([mock_llm_chunk(content="ok", finish_reason="stop")]),
    )
    tool_calls_messages = [
        ToolCall(
            id="tc1", function=FunctionCall(name="todo", arguments='{"action": "read"}')
        ),
        ToolCall(
            id="tc2", function=FunctionCall(name="todo", arguments='{"action": "read"}')
        ),
    ]
    assistant_msg = LLMMessage(
        role=Role.assistant, content="Calling tools...", tool_calls=tool_calls_messages
    )
    # seed a history where the assistant requested two tool calls...
    agent.messages = [
        agent.messages[0],
        assistant_msg,
        # only one tool responded: the second is missing
        LLMMessage(
            role=Role.tool, tool_call_id="tc1", name="todo", content="Retrieved 0 todos"
        ),
    ]
    await act_and_collect_events(agent, "Proceed")
    tool_msgs = [m for m in agent.messages if m.role == Role.tool]
    assert any(m.tool_call_id == "tc2" for m in tool_msgs)
    # find placeholder message for tc2
    placeholder = next(m for m in tool_msgs if m.tool_call_id == "tc2")
    assert placeholder.name == "todo"
    assert (
        placeholder.content
        == "<user_cancellation>Tool execution interrupted - no response available</user_cancellation>"
    )
@pytest.mark.asyncio
async def test_ensure_assistant_after_tool_appends_understood() -> None:
    """A dangling tool message must be followed by an 'Understood.' reply."""
    agent = Agent(
        make_config(),
        auto_approve=True,
        backend=FakeBackend([mock_llm_chunk(content="ok", finish_reason="stop")]),
    )
    dangling_tool_msg = LLMMessage(
        role=Role.tool, tool_call_id="tc_z", name="todo", content="Done"
    )
    agent.messages = [agent.messages[0], dangling_tool_msg]
    await act_and_collect_events(agent, "Next")
    # locate the seeded tool message and inspect what directly follows it
    tool_index = next(
        i for i, msg in enumerate(agent.messages) if msg.role == Role.tool
    )
    follow_up = agent.messages[tool_index + 1]
    assert follow_up.role == Role.assistant
    assert follow_up.content == "Understood."

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
import pytest
from tests.mock.mock_backend_factory import mock_backend_factory
from tests.mock.utils import mock_llm_chunk
from tests.stubs.fake_backend import FakeBackend
from vibe.core import run_programmatic
from vibe.core.config import Backend, SessionLoggingConfig, VibeConfig
from vibe.core.types import LLMMessage, OutputFormat, Role
class SpyStreamingFormatter:
    """Test double that records every message the formatter is handed."""

    def __init__(self) -> None:
        # Each entry is the (role, content) pair of one added message.
        self.emitted: list[tuple[Role, str | None]] = []

    def on_message_added(self, message: LLMMessage) -> None:
        self.emitted.append((message.role, message.content))

    def on_event(self, _event) -> None:
        # Events are irrelevant to these tests; deliberately a no-op.
        pass

    def finalize(self) -> str | None:
        # Nothing is buffered, so there is never a trailing chunk to flush.
        return None
def test_run_programmatic_preload_streaming_is_batched(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Preloaded history is replayed to the streaming formatter in order,
    before the new prompt and the model's answer."""
    spy = SpyStreamingFormatter()
    monkeypatch.setattr(
        "vibe.core.programmatic.create_formatter", lambda *_args, **_kwargs: spy
    )
    with mock_backend_factory(
        Backend.MISTRAL,
        lambda provider, **kwargs: FakeBackend([
            mock_llm_chunk(
                content="Decorators are wrappers that modify function behavior.",
                finish_reason="stop",
            )
        ]),
    ):
        cfg = VibeConfig(
            session_logging=SessionLoggingConfig(enabled=False),
            system_prompt_id="tests",
            include_project_context=False,
            include_prompt_detail=False,
            include_model_info=False,
        )
        previous = [
            LLMMessage(
                role=Role.system, content="This system message should be ignored."
            ),
            LLMMessage(
                role=Role.user, content="Previously, you told me about decorators."
            ),
            LLMMessage(
                role=Role.assistant,
                content="Sure, decorators allow you to wrap functions.",
            ),
        ]
        run_programmatic(
            config=cfg,
            prompt="Can you summarize what decorators are?",
            output_format=OutputFormat.STREAMING,
            previous_messages=previous,
        )
    # the system message from `previous` is dropped; the canonical system
    # prompt is emitted first instead
    roles = [r for r, _ in spy.emitted]
    assert roles == [
        Role.system,
        Role.user,
        Role.assistant,
        Role.user,
        Role.assistant,
    ]
    assert (
        spy.emitted[0][1] == "You are Vibe, a super useful programming assistant."
    )
    assert spy.emitted[1][1] == "Previously, you told me about decorators."
    assert spy.emitted[2][1] == "Sure, decorators allow you to wrap functions."
    assert spy.emitted[3][1] == "Can you summarize what decorators are?"
    assert (
        spy.emitted[4][1]
        == "Decorators are wrappers that modify function behavior."
    )
def test_run_programmatic_ignores_system_messages_in_previous(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """System messages inside `previous_messages` must be filtered out; only
    the canonical system prompt is emitted."""
    spy = SpyStreamingFormatter()
    monkeypatch.setattr(
        "vibe.core.programmatic.create_formatter", lambda *_args, **_kwargs: spy
    )
    with mock_backend_factory(
        Backend.MISTRAL,
        lambda provider, **kwargs: FakeBackend([mock_llm_chunk(content="Understood.")]),
    ):
        cfg = VibeConfig(
            session_logging=SessionLoggingConfig(enabled=False),
            system_prompt_id="tests",
            include_project_context=False,
            include_prompt_detail=False,
            include_model_info=False,
        )
        run_programmatic(
            config=cfg,
            prompt="Let's move on to practical examples.",
            output_format=OutputFormat.STREAMING,
            previous_messages=[
                LLMMessage(
                    role=Role.system,
                    content="First system message that should be ignored.",
                ),
                LLMMessage(role=Role.user, content="Continue our previous discussion."),
                LLMMessage(
                    role=Role.system,
                    content="Second system message that should be ignored.",
                ),
            ],
            auto_approve=True,
        )
    # neither seeded system message appears; the canonical prompt leads
    roles = [r for r, _ in spy.emitted]
    assert roles == [Role.system, Role.user, Role.user, Role.assistant]
    assert (
        spy.emitted[0][1] == "You are Vibe, a super useful programming assistant."
    )
    assert spy.emitted[1][1] == "Continue our previous discussion."
    assert spy.emitted[2][1] == "Let's move on to practical examples."
    assert spy.emitted[3][1] == "Understood."

View File

@@ -0,0 +1,101 @@
from __future__ import annotations
import json
from pathlib import Path
from vibe.cli.history_manager import HistoryManager
def test_history_manager_normalizes_loaded_entries_like_numbers_to_strings(
    tmp_path: Path,
) -> None:
    """Numeric JSON entries on disk must come back as strings when recalled."""
    # Real I/O is acceptable here: this test is a focused bugfix check, not a
    # refactoring of the HistoryManager storage layer.
    history_file = tmp_path / "history.jsonl"
    raw_lines = [json.dumps(entry) for entry in ("hello", 123)]
    history_file.write_text("\n".join(raw_lines) + "\n", encoding="utf-8")
    manager = HistoryManager(history_file)
    assert manager.get_previous(current_input="", prefix="1") == "123"
def test_history_manager_retains_a_fixed_number_of_entries(tmp_path: Path) -> None:
    """Only the most recent `max_entries` inputs survive a reload."""
    history_file = tmp_path / "history.jsonl"
    writer = HistoryManager(history_file, max_entries=3)
    for entry in ("first", "second", "third", "fourth"):
        writer.add(entry)
    reader = HistoryManager(history_file)
    assert reader.get_previous(current_input="", prefix="") == "fourth"
    assert reader.get_previous(current_input="", prefix="") == "third"
    assert reader.get_previous(current_input="", prefix="") == "second"
    # "first" is not proposed as we defined number of entries to 3
    assert reader.get_previous(current_input="", prefix="") is None
def test_history_manager_filters_invalid_and_duplicated_entries(tmp_path: Path) -> None:
    """Blank, whitespace-only, and consecutive duplicate entries are dropped."""
    history_file = tmp_path / "history.jsonl"
    writer = HistoryManager(history_file, max_entries=5)
    for entry in ("", "   ", "first", "second", "second", "third"):
        # "" is empty, "   " is trimmed away, and the repeated "second"
        # is a duplicate — none of them should be persisted twice.
        writer.add(entry)
    reader = HistoryManager(history_file)
    assert reader.get_previous(current_input="", prefix="") == "third"
    assert reader.get_previous(current_input="", prefix="") == "second"
    assert reader.get_previous(current_input="", prefix="") == "first"
    assert reader.get_previous(current_input="", prefix="") is None
    assert reader.get_previous(current_input="", prefix="") is None
def test_history_manager_filters_commands(tmp_path: Path) -> None:
    """Slash-commands are stored but never offered when browsing history.

    Fix: the first assertion compared with ``== None``; PEP 8 (E711) requires
    identity comparison, ``is None``.
    """
    history_file = tmp_path / "history.jsonl"
    manager = HistoryManager(history_file, max_entries=5)
    manager.add("first")
    manager.add("/skip")
    reloaded = HistoryManager(history_file)
    # nothing matches the "/" prefix: commands are filtered out of history
    assert reloaded.get_previous(current_input="", prefix="/") is None
    assert reloaded.get_previous(current_input="", prefix="") == "first"
    assert reloaded.get_previous(current_input="", prefix="") is None
def test_history_manager_allows_navigation_round_trip(tmp_path: Path) -> None:
    """Walking back through history and forward again restores the draft."""
    manager = HistoryManager(tmp_path / "history.jsonl")
    manager.add("alpha")
    manager.add("beta")
    # backward: newest first
    assert manager.get_previous(current_input="typed") == "beta"
    assert manager.get_previous(current_input="typed") == "alpha"
    # forward: back to the in-progress input, then exhausted
    assert manager.get_next() == "beta"
    assert manager.get_next() == "typed"
    assert manager.get_next() is None
def test_history_manager_prefix_filtering(tmp_path: Path) -> None:
    """Only entries starting with the given prefix are offered, newest first."""
    manager = HistoryManager(tmp_path / "history.jsonl")
    for entry in ("foo", "bar", "fizz"):
        manager.add(entry)
    matches = [
        manager.get_previous(current_input="", prefix="f") for _ in range(3)
    ]
    assert matches == ["fizz", "foo", None]

View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import sys
import pytest
from vibe.core.config import VibeConfig
from vibe.core.system_prompt import get_universal_system_prompt
from vibe.core.tools.manager import ToolManager
def test_get_universal_system_prompt_includes_windows_prompt_on_windows(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """On win32, the system prompt must embed Windows-specific shell guidance."""
    # simulate a Windows host: platform plus the cmd.exe shell env var
    monkeypatch.setattr(sys, "platform", "win32")
    monkeypatch.setenv("COMSPEC", "C:\\Windows\\System32\\cmd.exe")
    config = VibeConfig(
        system_prompt_id="tests",
        include_project_context=False,
        include_prompt_detail=True,
        include_model_info=False,
    )
    tool_manager = ToolManager(config)
    prompt = get_universal_system_prompt(tool_manager, config)
    assert "You are Vibe, a super useful programming assistant." in prompt
    assert (
        "The operating system is Windows with shell `C:\\Windows\\System32\\cmd.exe`"
        in prompt
    )
    assert "DO NOT use Unix commands like `ls`, `grep`, `cat`" in prompt
    assert "Use: `dir` (Windows) for directory listings" in prompt
    assert "Use: backslashes (\\\\) for paths" in prompt
    assert "Check command availability with: `where command` (Windows)" in prompt
    assert "Script shebang: Not applicable on Windows" in prompt

107
tests/test_tagged_text.py Normal file
View File

@@ -0,0 +1,107 @@
from __future__ import annotations
import pytest
from vibe.core.utils import CANCELLATION_TAG, KNOWN_TAGS, TaggedText
def test_tagged_text_creation_without_tag() -> None:
    """Without a tag, str() is just the raw message."""
    plain = TaggedText("Hello world")
    assert (plain.message, plain.tag) == ("Hello world", "")
    assert str(plain) == "Hello world"
def test_tagged_text_creation_with_tag() -> None:
    """With a tag, str() wraps the message in <tag>...</tag>."""
    tagged = TaggedText("User cancelled", CANCELLATION_TAG)
    assert (tagged.message, tagged.tag) == ("User cancelled", CANCELLATION_TAG)
    expected = f"<{CANCELLATION_TAG}>User cancelled</{CANCELLATION_TAG}>"
    assert str(tagged) == expected
@pytest.mark.parametrize("tag", KNOWN_TAGS)
def test_tagged_text_from_string_with_known_tag(tag: str) -> None:
    """Every known tag is recognized and stripped from the message."""
    parsed = TaggedText.from_string(f"<{tag}>This is a tagged text</{tag}>")
    assert (parsed.message, parsed.tag) == ("This is a tagged text", tag)
@pytest.mark.parametrize("tag", KNOWN_TAGS)
def test_tagged_text_from_string_with_known_tag_multiline(tag: str) -> None:
    # NOTE(review): despite the name, this input is identical to the
    # single-line test above — it contains no newline, so multiline parsing is
    # not actually exercised. Consider a message with an embedded "\n" once
    # TaggedText's parser is confirmed to support it.
    text = f"<{tag}>This is a tagged text</{tag}>"
    tagged = TaggedText.from_string(text)
    assert tagged.message == "This is a tagged text"
    assert tagged.tag == tag
@pytest.mark.parametrize("tag", KNOWN_TAGS)
def test_tagged_text_from_string_with_known_tag_whitespace(tag: str) -> None:
    """Whitespace inside the tag body is preserved verbatim."""
    parsed = TaggedText.from_string(f"<{tag}> This is a tagged text </{tag}>")
    assert (parsed.message, parsed.tag) == (" This is a tagged text ", tag)
def test_tagged_text_from_string_with_unknown_tag() -> None:
    """Unrecognized tags are left in place and no tag is extracted."""
    raw = "<unknown_tag>Some content</unknown_tag>"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == (raw, "")
def test_tagged_text_from_string_with_text_before_tag() -> None:
    """Text preceding a known tag is merged into the extracted message."""
    raw = f"Prefix text <{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}>"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == ("Prefix text Content", CANCELLATION_TAG)
def test_tagged_text_from_string_with_text_after_tag() -> None:
    """Text following a known tag is merged into the extracted message."""
    raw = f"<{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}> Suffix text"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == ("Content Suffix text", CANCELLATION_TAG)
def test_tagged_text_from_string_with_text_before_and_after_tag() -> None:
    """Surrounding text on both sides is merged with the tag's content."""
    raw = f"Before <{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}> After"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == ("Before Content After", CANCELLATION_TAG)
def test_tagged_text_from_string_without_tags() -> None:
    """Plain text parses to an untagged TaggedText with the text unchanged."""
    raw = "Just plain text without any tags"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == (raw, "")
def test_tagged_text_from_string_empty() -> None:
    """An empty string parses to an empty, untagged TaggedText."""
    parsed = TaggedText.from_string("")
    assert (parsed.message, parsed.tag) == ("", "")
def test_tagged_text_from_string_mismatched_tags() -> None:
    """Open/close tag mismatch is treated as plain text, not a tagged span."""
    raw = f"<{CANCELLATION_TAG}>Content</different_tag>"
    parsed = TaggedText.from_string(raw)
    assert (parsed.message, parsed.tag) == (raw, "")
def test_tagged_text_round_trip() -> None:
    """str() followed by from_string() reproduces message and tag."""
    original = TaggedText("User cancelled", CANCELLATION_TAG)
    parsed = TaggedText.from_string(str(original))
    assert (parsed.message, parsed.tag) == (original.message, original.tag)
def test_tagged_text_round_trip_no_tag() -> None:
    """Round trip of an untagged TaggedText keeps message and empty tag."""
    original = TaggedText("Plain message")
    parsed = TaggedText.from_string(str(original))
    assert (parsed.message, parsed.tag) == (original.message, original.tag)

View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import json
from pathlib import Path
import pytest
from vibe.cli.history_manager import HistoryManager
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.textual_ui.widgets.chat_input.body import ChatInputBody
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
from vibe.core.config import SessionLoggingConfig, VibeConfig
@pytest.fixture
def vibe_config() -> VibeConfig:
    """Config with session logging disabled so tests leave no artifacts."""
    return VibeConfig(
        session_logging=SessionLoggingConfig(enabled=False),
    )
@pytest.fixture
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
    """App under test.

    Fix: dropped the `tmp_path` parameter — it was requested but never used.
    """
    return VibeApp(config=vibe_config)
@pytest.fixture
def history_file(tmp_path: Path) -> Path:
    """Write a three-entry JSONL history file and return its path."""
    path = tmp_path / "history.jsonl"
    lines = [json.dumps(entry) for entry in ("hello", "hi there", "how are you?")]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return path
def inject_history_file(vibe_app: VibeApp, history_file: Path) -> None:
    """Swap the ChatInputBody's history for one backed by `history_file`."""
    # Dependency Injection would help here, but as we don't have it yet:
    # reach into the widget tree and replace the manager manually.
    body = vibe_app.query_one(ChatInputBody)
    body.history = HistoryManager(history_file)
@pytest.mark.asyncio
async def test_ui_navigation_through_input_history(
    vibe_app: VibeApp, history_file: Path
) -> None:
    """Arrow keys walk history newest-to-oldest, clamp at the oldest entry,
    and return to an empty input after the newest one."""
    async with vibe_app.run_test() as pilot:
        inject_history_file(vibe_app, history_file)
        chat_input = vibe_app.query_one(ChatInputContainer)
        await pilot.press("up")
        assert chat_input.value == "how are you?"
        await pilot.press("up")
        assert chat_input.value == "hi there"
        await pilot.press("up")
        assert chat_input.value == "hello"
        await pilot.press("up")
        # cannot go further up
        assert chat_input.value == "hello"
        await pilot.press("down")
        assert chat_input.value == "hi there"
        await pilot.press("down")
        assert chat_input.value == "how are you?"
        await pilot.press("down")
        # walked past the newest entry: back to an empty draft
        assert chat_input.value == ""
@pytest.mark.asyncio
async def test_ui_does_nothing_if_command_completion_is_active(
    vibe_app: VibeApp, history_file: Path
) -> None:
    """While the "/" command menu is open, arrows must not touch history."""
    async with vibe_app.run_test() as pilot:
        inject_history_file(vibe_app, history_file)
        chat_input = vibe_app.query_one(ChatInputContainer)
        await pilot.press("/")
        assert chat_input.value == "/"
        await pilot.press("up")
        assert chat_input.value == "/"
        await pilot.press("down")
        assert chat_input.value == "/"
@pytest.mark.asyncio
async def test_ui_does_not_prevent_arrow_down_to_move_cursor_to_bottom_lines(
    vibe_app: VibeApp,
) -> None:
    """In a multi-line draft, the down arrow moves the cursor within the text
    area instead of being swallowed by history navigation."""
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        textarea = chat_input.input_widget
        assert textarea is not None
        await pilot.press(*"test")
        # ctrl+j inserts a newline, producing a three-line draft
        await pilot.press("ctrl+j", "ctrl+j")
        assert chat_input.value == "test\n\n"
        assert textarea.text.count("\n") == 2
        initial_row = textarea.cursor_location[0]
        assert initial_row == 2, f"Expected cursor on line 2, got line {initial_row}"
        await pilot.press("up")
        assert textarea.cursor_location[0] == 1, "First arrow up should move to line 1"
        await pilot.press("up")
        assert textarea.cursor_location[0] == 0, (
            "Second arrow up should move to line 0 (first line)"
        )
        await pilot.press("down")
        final_row = textarea.cursor_location[0]
        assert final_row == 1, f"cursor is still on line {final_row}."
@pytest.mark.asyncio
async def test_ui_resumes_arrow_down_after_manual_move(
    vibe_app: VibeApp, tmp_path: Path
) -> None:
    """After recalling a multi-line entry, a manual cursor move hands the down
    arrow back to the text area instead of history navigation."""
    history_path = tmp_path / "history.jsonl"
    history_path.write_text(
        json.dumps("first line\nsecond line") + "\n", encoding="utf-8"
    )
    async with vibe_app.run_test() as pilot:
        inject_history_file(vibe_app, history_path)
        chat_input = vibe_app.query_one(ChatInputContainer)
        textarea = chat_input.input_widget
        assert textarea is not None
        await pilot.press("up")
        assert chat_input.value == "first line\nsecond line"
        # history recall leaves the cursor at the end of the first line
        assert textarea.cursor_location == (0, len("first line"))
        await pilot.press("left")
        await pilot.press("down")
        # the down key moved the cursor; the recalled text is untouched
        assert textarea.cursor_location[0] == 1
        assert chat_input.value == "first line\nsecond line"

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator, Callable
import time
from types import SimpleNamespace
import pytest
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
from vibe.cli.textual_ui.widgets.messages import InterruptMessage, UserMessage
from vibe.core.agent import Agent
from vibe.core.config import SessionLoggingConfig, VibeConfig
from vibe.core.types import BaseEvent
async def _wait_for(
    pilot, condition: Callable[[], object | None], timeout: float = 3.0
) -> object | None:
    """Poll `condition` every 50 ms until it returns a truthy value.

    Returns the first truthy result, or None once `timeout` seconds elapse.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        result = condition()
        if result:
            return result
        await pilot.pause(0.05)
    return None
class StubAgent(Agent):
    """Minimal Agent stand-in providing only the attributes the UI reads."""

    def __init__(self) -> None:
        # Deliberately does NOT call super().__init__ — no config or backend
        # is needed for these UI tests.
        self.messages: list = []
        self.stats = SimpleNamespace(context_tokens=0)
        self.approval_callback = None

    async def initialize(self) -> None:
        return

    async def act(self, msg: str) -> AsyncGenerator[BaseEvent]:
        # `if False: yield` makes this a generator function that yields
        # nothing, matching Agent.act's async-generator interface.
        if False:
            yield msg
@pytest.fixture
def vibe_config() -> VibeConfig:
    """Hermetic config: no session logs and no update checks."""
    return VibeConfig(
        session_logging=SessionLoggingConfig(enabled=False),
        enable_update_checks=False,
    )
@pytest.fixture
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
    """Fresh app instance built from the hermetic test config."""
    app = VibeApp(config=vibe_config)
    return app
def _patch_delayed_init(
    monkeypatch: pytest.MonkeyPatch, init_event: asyncio.Event
) -> None:
    """Replace VibeApp._initialize_agent with a version gated on `init_event`.

    The fake initializer blocks until the test sets `init_event`, then installs
    a StubAgent; a cancellation leaves `self.agent` as None so retries work.
    """
    async def _fake_initialize(self: VibeApp) -> None:
        # mirror the real guard: never run two initializations at once
        if self.agent or self._agent_initializing:
            return
        self._agent_initializing = True
        try:
            await init_event.wait()
            self.agent = StubAgent()
        except asyncio.CancelledError:
            # interrupted: report "no agent" so the app can retry later
            self.agent = None
            return
        finally:
            self._agent_initializing = False
            self._agent_init_task = None
    monkeypatch.setattr(VibeApp, "_initialize_agent", _fake_initialize, raising=True)
@pytest.mark.asyncio
async def test_shows_user_message_as_pending_until_agent_is_initialized(
    vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A message sent before the agent exists keeps the "pending" CSS class
    until initialization completes."""
    init_event = asyncio.Event()
    _patch_delayed_init(monkeypatch, init_event)
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "Hello"
        # press("enter") blocks until handled, so run it concurrently
        press_task = asyncio.create_task(pilot.press("enter"))
        user_message = await _wait_for(
            pilot, lambda: next(iter(vibe_app.query(UserMessage)), None)
        )
        assert isinstance(user_message, UserMessage)
        assert user_message.has_class("pending")
        init_event.set()
        await press_task
        assert not user_message.has_class("pending")
@pytest.mark.asyncio
async def test_can_interrupt_pending_message_during_initialization(
    vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Escape during agent initialization cancels it: the pending class is
    removed, an InterruptMessage appears, and no agent is installed."""
    init_event = asyncio.Event()
    _patch_delayed_init(monkeypatch, init_event)
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "Hello"
        press_task = asyncio.create_task(pilot.press("enter"))
        user_message = await _wait_for(
            pilot, lambda: next(iter(vibe_app.query(UserMessage)), None)
        )
        assert isinstance(user_message, UserMessage)
        assert user_message.has_class("pending")
        # escape while initialization is still waiting on init_event
        await pilot.press("escape")
        await press_task
        assert not user_message.has_class("pending")
        assert vibe_app.query(InterruptMessage)
        assert vibe_app.agent is None
@pytest.mark.asyncio
async def test_retry_initialization_after_interrupt(
    vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
) -> None:
    """After an interrupted initialization, a second message restarts it and
    succeeds once the init event fires."""
    init_event = asyncio.Event()
    _patch_delayed_init(monkeypatch, init_event)
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "First Message"
        press_task = asyncio.create_task(pilot.press("enter"))
        await _wait_for(pilot, lambda: next(iter(vibe_app.query(UserMessage)), None))
        await pilot.press("escape")
        await press_task
        # first attempt was cancelled cleanly
        assert vibe_app.agent is None
        assert vibe_app._agent_init_task is None
        chat_input.value = "Second Message"
        press_task_2 = asyncio.create_task(pilot.press("enter"))
        def get_second_message():
            # the retry is underway once a second UserMessage widget exists
            messages = list(vibe_app.query(UserMessage))
            if len(messages) >= 2:
                return messages[-1]
            return None
        user_message_2 = await _wait_for(pilot, get_second_message)
        assert isinstance(user_message_2, UserMessage)
        assert user_message_2.has_class("pending")
        assert vibe_app.agent is None
        init_event.set()
        await press_task_2
        assert not user_message_2.has_class("pending")
        assert vibe_app.agent is not None

86
tests/tools/test_bash.py Normal file
View File

@@ -0,0 +1,86 @@
from __future__ import annotations
import pytest
from vibe.core.tools.base import BaseToolState, ToolError, ToolPermission
from vibe.core.tools.builtins.bash import Bash, BashArgs, BashToolConfig
@pytest.fixture
def bash(tmp_path):
    """Bash tool rooted at the test's temporary directory."""
    return Bash(config=BashToolConfig(workdir=tmp_path), state=BaseToolState())
@pytest.mark.asyncio
async def test_runs_echo_successfully(bash):
    """A trivial echo exits 0 with its output on stdout and nothing on stderr."""
    outcome = await bash.run(BashArgs(command="echo hello"))
    assert (outcome.returncode, outcome.stdout, outcome.stderr) == (0, "hello\n", "")
@pytest.mark.asyncio
async def test_fails_cat_command_with_missing_file(bash):
    """A non-zero exit raises ToolError carrying the code and stderr details."""
    with pytest.raises(ToolError) as err:
        await bash.run(BashArgs(command="cat missing_file.txt"))
    failure_text = str(err.value)
    for expected in ("Command failed", "Return code: 1", "No such file or directory"):
        assert expected in failure_text
@pytest.mark.asyncio
async def test_uses_effective_workdir(tmp_path):
    """Commands execute inside the configured workdir."""
    tool = Bash(config=BashToolConfig(workdir=tmp_path), state=BaseToolState())
    outcome = await tool.run(BashArgs(command="pwd"))
    assert outcome.stdout.strip() == str(tmp_path)
@pytest.mark.asyncio
async def test_handles_timeout(bash):
    """Exceeding the timeout raises a ToolError naming the limit."""
    with pytest.raises(ToolError) as err:
        await bash.run(BashArgs(command="sleep 2", timeout=1))
    failure_text = str(err.value)
    assert "Command timed out after 1s" in failure_text
@pytest.mark.asyncio
async def test_truncates_output_to_max_bytes():
    """stdout is clipped to `max_output_bytes`; stderr and returncode stay intact.

    Fix: dropped the unused `bash` fixture parameter — this test builds its own
    tool instance with a tiny output budget and never touched the fixture.
    """
    config = BashToolConfig(workdir=None, max_output_bytes=5)
    bash_tool = Bash(config=config, state=BaseToolState())
    result = await bash_tool.run(BashArgs(command="printf 'abcdefghij'"))
    assert result.stdout == "abcde"
    assert result.stderr == ""
    assert result.returncode == 0
@pytest.mark.asyncio
async def test_decodes_non_utf8_bytes(bash):
    """Invalid UTF-8 on stdout must not crash the tool."""
    result = await bash.run(BashArgs(command="printf '\\xff\\xfe'"))
    # accept both possible encodings, as some shells emit escaped bytes as literal strings
    assert result.stdout in {"��", "\xff\xfe", r"\xff\xfe"}
    assert result.stderr == ""
def test_check_allowlist_denylist():
    """Allowlisted -> ALWAYS, denylisted -> NEVER, mixed/empty -> no decision."""
    tool = Bash(
        config=BashToolConfig(allowlist=["echo", "pwd"], denylist=["rm"]),
        state=BaseToolState(),
    )
    verdict = tool.check_allowlist_denylist
    assert verdict(BashArgs(command="echo hi")) is ToolPermission.ALWAYS
    assert verdict(BashArgs(command="rm -rf /tmp")) is ToolPermission.NEVER
    # a chained command mixing listed and unlisted binaries gets no verdict
    assert verdict(BashArgs(command="pwd && whoami")) is None
    assert verdict(BashArgs(command="")) is None

347
tests/tools/test_grep.py Normal file
View File

@@ -0,0 +1,347 @@
from __future__ import annotations
import shutil
import pytest
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.grep import (
Grep,
GrepArgs,
GrepBackend,
GrepState,
GrepToolConfig,
)
@pytest.fixture
def grep(tmp_path):
    """Grep tool rooted at the test's temporary directory."""
    return Grep(config=GrepToolConfig(workdir=tmp_path), state=GrepState())
@pytest.fixture
def grep_gnu_only(tmp_path, monkeypatch):
    """Grep tool with ripgrep hidden so the GNU grep backend is selected."""
    real_which = shutil.which

    def hide_ripgrep(cmd):
        # pretend `rg` is absent; every other lookup behaves normally
        return None if cmd == "rg" else real_which(cmd)

    monkeypatch.setattr("shutil.which", hide_ripgrep)
    return Grep(config=GrepToolConfig(workdir=tmp_path), state=GrepState())
def test_detects_ripgrep_when_available(grep):
    """When `rg` is on PATH, backend detection must pick ripgrep.

    Fix: the assertion used to be wrapped in `if shutil.which("rg")`, so the
    test passed vacuously on machines without ripgrep; it now reports a skip
    instead of a silent pass.
    """
    if not shutil.which("rg"):
        pytest.skip("ripgrep (rg) is not installed")
    assert grep._detect_backend() == GrepBackend.RIPGREP
def test_falls_back_to_gnu_grep(grep, monkeypatch):
    """With ripgrep hidden, detection must fall back to GNU grep.

    Fix: the assertion used to be wrapped in `if shutil.which("grep")`, so the
    test passed vacuously on machines without grep; it now reports a skip.
    """
    original_which = shutil.which
    if not original_which("grep"):
        pytest.skip("GNU grep is not installed")

    def mock_which(cmd):
        # hide ripgrep; everything else resolves normally
        return None if cmd == "rg" else original_which(cmd)

    monkeypatch.setattr("shutil.which", mock_which)
    assert grep._detect_backend() == GrepBackend.GNU_GREP
def test_raises_error_if_no_grep_available(grep, monkeypatch):
    """With no search binary on PATH, backend detection fails loudly."""
    monkeypatch.setattr("shutil.which", lambda cmd: None)
    with pytest.raises(ToolError) as err:
        grep._detect_backend()
    assert "Neither ripgrep (rg) nor grep is installed" in str(err.value)
@pytest.mark.asyncio
async def test_finds_pattern_in_file(grep, tmp_path):
    """A single match reports count 1 plus the matching file and line."""
    (tmp_path / "test.py").write_text("def hello():\n print('world')\n")
    result = await grep.run(GrepArgs(pattern="hello"))
    assert result.match_count == 1
    assert "hello" in result.matches
    assert "test.py" in result.matches
    assert not result.was_truncated
@pytest.mark.asyncio
async def test_finds_multiple_matches(grep, tmp_path):
    """Every matching line is counted and included in the output."""
    (tmp_path / "test.py").write_text("foo\nbar\nfoo\nbaz\nfoo\n")
    result = await grep.run(GrepArgs(pattern="foo"))
    assert result.match_count == 3
    assert result.matches.count("foo") == 3
    assert not result.was_truncated
@pytest.mark.asyncio
async def test_returns_empty_on_no_matches(grep, tmp_path):
    """No hits yields a zero count and an empty matches string, not an error."""
    (tmp_path / "test.py").write_text("def hello():\n pass\n")
    result = await grep.run(GrepArgs(pattern="nonexistent"))
    assert result.match_count == 0
    assert result.matches == ""
    assert not result.was_truncated
@pytest.mark.asyncio
async def test_fails_with_empty_pattern(grep):
    """An empty pattern is rejected with an explicit ToolError."""
    with pytest.raises(ToolError) as err:
        await grep.run(GrepArgs(pattern=""))
    assert "Empty search pattern" in str(err.value)
@pytest.mark.asyncio
async def test_fails_with_nonexistent_path(grep):
    """Searching a missing path raises a ToolError naming the problem."""
    with pytest.raises(ToolError) as err:
        await grep.run(GrepArgs(pattern="test", path="nonexistent"))
    assert "Path does not exist" in str(err.value)
@pytest.mark.asyncio
async def test_searches_in_specific_path(grep, tmp_path):
    """A `path` argument confines the search to that subtree."""
    subdir = tmp_path / "subdir"
    subdir.mkdir()
    (subdir / "test.py").write_text("match here\n")
    (tmp_path / "other.py").write_text("match here too\n")
    result = await grep.run(GrepArgs(pattern="match", path="subdir"))
    assert result.match_count == 1
    assert "subdir" in result.matches and "test.py" in result.matches
    # the sibling file outside `subdir` must not appear
    assert "other.py" not in result.matches
@pytest.mark.asyncio
async def test_truncates_to_max_matches(grep, tmp_path):
    """Results are capped at `max_matches` and flagged as truncated."""
    (tmp_path / "test.py").write_text("\n".join(f"line {i}" for i in range(200)))
    result = await grep.run(GrepArgs(pattern="line", max_matches=50))
    assert result.match_count == 50
    assert result.was_truncated
@pytest.mark.asyncio
async def test_truncates_to_max_output_bytes(grep, tmp_path):
    """Output is clipped at `max_output_bytes` and flagged as truncated."""
    config = GrepToolConfig(workdir=tmp_path, max_output_bytes=100)
    grep_tool = Grep(config=config, state=GrepState())
    (tmp_path / "test.py").write_text("\n".join("x" * 100 for _ in range(10)))
    result = await grep_tool.run(GrepArgs(pattern="x"))
    assert len(result.matches) <= 100
    assert result.was_truncated
@pytest.mark.asyncio
async def test_respects_default_ignore_patterns(grep, tmp_path):
    """Well-known directories like node_modules are excluded by default."""
    (tmp_path / "included.py").write_text("match\n")
    node_modules = tmp_path / "node_modules"
    node_modules.mkdir()
    (node_modules / "excluded.js").write_text("match\n")
    result = await grep.run(GrepArgs(pattern="match"))
    assert "included.py" in result.matches
    assert "excluded.js" not in result.matches
@pytest.mark.asyncio
async def test_respects_vibeignore_file(grep, tmp_path):
    """Directory and glob patterns from .vibeignore are honored."""
    (tmp_path / ".vibeignore").write_text("custom_dir/\n*.tmp\n")
    custom_dir = tmp_path / "custom_dir"
    custom_dir.mkdir()
    (custom_dir / "excluded.py").write_text("match\n")
    (tmp_path / "excluded.tmp").write_text("match\n")
    (tmp_path / "included.py").write_text("match\n")
    result = await grep.run(GrepArgs(pattern="match"))
    assert "included.py" in result.matches
    assert "excluded.py" not in result.matches
    assert "excluded.tmp" not in result.matches
@pytest.mark.asyncio
async def test_ignores_comments_in_vibeignore(grep, tmp_path):
    """Comment lines in .vibeignore must not break the search."""
    (tmp_path / ".vibeignore").write_text("# comment\npattern/\n# another comment\n")
    (tmp_path / "file.py").write_text("match\n")
    result = await grep.run(GrepArgs(pattern="match"))
    assert result.match_count >= 1
@pytest.mark.asyncio
async def test_tracks_search_history(grep, tmp_path):
    """Every executed pattern is appended to the tool state's history."""
    (tmp_path / "test.py").write_text("content\n")
    await grep.run(GrepArgs(pattern="first"))
    await grep.run(GrepArgs(pattern="second"))
    await grep.run(GrepArgs(pattern="third"))
    assert grep.state.search_history == ["first", "second", "third"]
@pytest.mark.asyncio
async def test_uses_effective_workdir(tmp_path):
    """A '.' path resolves relative to the configured workdir."""
    config = GrepToolConfig(workdir=tmp_path)
    grep_tool = Grep(config=config, state=GrepState())
    (tmp_path / "test.py").write_text("match\n")
    result = await grep_tool.run(GrepArgs(pattern="match", path="."))
    assert result.match_count == 1
    assert "test.py" in result.matches
@pytest.mark.skipif(not shutil.which("grep"), reason="GNU grep not available")
class TestGnuGrepBackend:
    @pytest.mark.asyncio
    async def test_finds_pattern_in_file(self, grep_gnu_only, tmp_path):
        """GNU grep backend finds a single match and reports the file name."""
        (tmp_path / "test.py").write_text("def hello():\n print('world')\n")
        result = await grep_gnu_only.run(GrepArgs(pattern="hello"))
        assert result.match_count == 1
        assert "hello" in result.matches
        assert "test.py" in result.matches
    @pytest.mark.asyncio
    async def test_finds_multiple_matches(self, grep_gnu_only, tmp_path):
        """GNU grep backend counts and reports every matching line."""
        (tmp_path / "test.py").write_text("foo\nbar\nfoo\nbaz\nfoo\n")
        result = await grep_gnu_only.run(GrepArgs(pattern="foo"))
        assert result.match_count == 3
        assert result.matches.count("foo") == 3
    @pytest.mark.asyncio
    async def test_returns_empty_on_no_matches(self, grep_gnu_only, tmp_path):
        """GNU grep backend reports zero matches without raising."""
        (tmp_path / "test.py").write_text("def hello():\n pass\n")
        result = await grep_gnu_only.run(GrepArgs(pattern="nonexistent"))
        assert result.match_count == 0
        assert result.matches == ""
@pytest.mark.asyncio
async def test_case_insensitive_for_lowercase_pattern(
self, grep_gnu_only, tmp_path
):
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
result = await grep_gnu_only.run(GrepArgs(pattern="hello"))
assert result.match_count == 3
@pytest.mark.asyncio
async def test_case_sensitive_for_mixed_case_pattern(self, grep_gnu_only, tmp_path):
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
result = await grep_gnu_only.run(GrepArgs(pattern="Hello"))
assert result.match_count == 1
@pytest.mark.asyncio
async def test_respects_exclude_patterns(self, grep_gnu_only, tmp_path):
(tmp_path / "included.py").write_text("match\n")
node_modules = tmp_path / "node_modules"
node_modules.mkdir()
(node_modules / "excluded.js").write_text("match\n")
result = await grep_gnu_only.run(GrepArgs(pattern="match"))
assert "included.py" in result.matches
assert "excluded.js" not in result.matches
@pytest.mark.asyncio
async def test_searches_in_specific_path(self, grep_gnu_only, tmp_path):
subdir = tmp_path / "subdir"
subdir.mkdir()
(subdir / "test.py").write_text("match here\n")
(tmp_path / "other.py").write_text("match here too\n")
result = await grep_gnu_only.run(GrepArgs(pattern="match", path="subdir"))
assert result.match_count == 1
assert "other.py" not in result.matches
@pytest.mark.asyncio
async def test_respects_vibeignore_file(self, grep_gnu_only, tmp_path):
(tmp_path / ".vibeignore").write_text("custom_dir/\n*.tmp\n")
custom_dir = tmp_path / "custom_dir"
custom_dir.mkdir()
(custom_dir / "excluded.py").write_text("match\n")
(tmp_path / "excluded.tmp").write_text("match\n")
(tmp_path / "included.py").write_text("match\n")
result = await grep_gnu_only.run(GrepArgs(pattern="match"))
assert "included.py" in result.matches
assert "excluded.py" not in result.matches
assert "excluded.tmp" not in result.matches
@pytest.mark.asyncio
async def test_truncates_to_max_matches(self, grep_gnu_only, tmp_path):
(tmp_path / "test.py").write_text("\n".join(f"line {i}" for i in range(200)))
result = await grep_gnu_only.run(GrepArgs(pattern="line", max_matches=50))
assert result.match_count == 50
assert result.was_truncated
@pytest.mark.skipif(not shutil.which("rg"), reason="ripgrep not available")
class TestRipgrepBackend:
    """Ripgrep-specific behavior: smart-case matching and .ignore handling."""

    @pytest.mark.asyncio
    async def test_smart_case_lowercase_pattern(self, grep, tmp_path):
        # All-lowercase pattern -> case-insensitive search.
        (tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
        result = await grep.run(GrepArgs(pattern="hello"))
        assert result.match_count == 3

    @pytest.mark.asyncio
    async def test_smart_case_mixed_case_pattern(self, grep, tmp_path):
        # Uppercase in the pattern -> case-sensitive search.
        (tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
        result = await grep.run(GrepArgs(pattern="Hello"))
        assert result.match_count == 1

    @pytest.mark.asyncio
    async def test_searches_ignored_files_when_use_default_ignore_false(
        self, grep, tmp_path
    ):
        # ripgrep honors .ignore by default; use_default_ignore=False disables it.
        (tmp_path / ".ignore").write_text("ignored_by_rg/\n")
        ignored_dir = tmp_path / "ignored_by_rg"
        ignored_dir.mkdir()
        (ignored_dir / "file.py").write_text("match\n")
        (tmp_path / "included.py").write_text("match\n")
        result_with_ignore = await grep.run(GrepArgs(pattern="match"))
        assert "included.py" in result_with_ignore.matches
        assert "ignored_by_rg" not in result_with_ignore.matches
        result_without_ignore = await grep.run(
            GrepArgs(pattern="match", use_default_ignore=False)
        )
        assert "included.py" in result_without_ignore.matches
        assert "ignored_by_rg/file.py" in result_without_ignore.matches

View File

@@ -0,0 +1,138 @@
from __future__ import annotations
from pathlib import Path
import time
import pytest
from textual.widgets import Static
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
from vibe.cli.textual_ui.widgets.messages import BashOutputMessage, ErrorMessage
from vibe.core.config import SessionLoggingConfig, VibeConfig
@pytest.fixture
def vibe_config(tmp_path: Path) -> VibeConfig:
    """Config rooted in a temp dir, with session logging disabled for tests."""
    logging_config = SessionLoggingConfig(enabled=False)
    return VibeConfig(session_logging=logging_config, workdir=tmp_path)
@pytest.fixture
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
    # App under test, built from the logging-disabled temp-dir config fixture.
    return VibeApp(config=vibe_config)
async def _wait_for_bash_output_message(
    vibe_app: VibeApp, pilot, timeout: float = 1.0
) -> BashOutputMessage:
    """Poll the app until a BashOutputMessage appears, or fail after `timeout`."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        found = list(vibe_app.query(BashOutputMessage))
        if found:
            return found[0]
        await pilot.pause(0.05)
    raise TimeoutError(f"BashOutputMessage did not appear within {timeout}s")
def assert_no_command_error(vibe_app: VibeApp) -> None:
    """Fail if any on-screen ErrorMessage carries a shell-command error phrase."""
    disallowed = (
        "Command failed",
        "Command timed out",
        "No command provided after '!'",
    )
    offending: list[str] = []
    for err in vibe_app.query(ErrorMessage):
        text = getattr(err, "_error", "")
        if text and any(phrase in text for phrase in disallowed):
            offending.append(text)
    assert not offending, f"Unexpected command errors: {offending}"
@pytest.mark.asyncio
async def test_ui_reports_no_output(vibe_app: VibeApp) -> None:
    # "!cmd" runs a shell command; a command with no stdout shows a placeholder.
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "!true"
        await pilot.press("enter")
        message = await _wait_for_bash_output_message(vibe_app, pilot)
        output_widget = message.query_one(".bash-output", Static)
        assert str(output_widget.render()) == "(no output)"
        assert_no_command_error(vibe_app)
@pytest.mark.asyncio
async def test_ui_shows_success_in_case_of_zero_code(vibe_app: VibeApp) -> None:
    # Exit code 0 renders the success icon and no failure icon.
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "!true"
        await pilot.press("enter")
        message = await _wait_for_bash_output_message(vibe_app, pilot)
        icon = message.query_one(".bash-exit-success", Static)
        # NOTE(review): expected icon text reads as empty here — possibly a
        # glyph lost in transit; confirm against the widget's actual content.
        assert str(icon.render()) == ""
        assert not list(message.query(".bash-exit-failure"))
@pytest.mark.asyncio
async def test_ui_shows_failure_in_case_of_non_zero_code(vibe_app: VibeApp) -> None:
    # A non-zero exit renders the failure icon and displays the exit code.
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "!bash -lc 'exit 7'"
        await pilot.press("enter")
        message = await _wait_for_bash_output_message(vibe_app, pilot)
        icon = message.query_one(".bash-exit-failure", Static)
        # NOTE(review): expected icon text reads as empty here — possibly a
        # glyph lost in transit; confirm against the widget's actual content.
        assert str(icon.render()) == ""
        code = message.query_one(".bash-exit-code", Static)
        assert "7" in str(code.render())
        assert not list(message.query(".bash-exit-success"))
@pytest.mark.asyncio
async def test_ui_handles_non_utf8_output(vibe_app: VibeApp) -> None:
    """Assert the UI accepts decoding a non-UTF8 sequence like `printf '\xf0\x9f\x98'`.
    Whereas `printf '\xf0\x9f\x98\x8b'` prints a smiley face (😋) and would work even without those changes.
    """
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "!printf '\\xff\\xfe'"
        await pilot.press("enter")
        message = await _wait_for_bash_output_message(vibe_app, pilot)
        output_widget = message.query_one(".bash-output", Static)
        # accept both possible encodings, as some shells emit escaped bytes as literal strings
        # NOTE(review): the first alternative appears to be U+FFFD replacement chars.
        assert str(output_widget.render()) in {"<EFBFBD><EFBFBD>", "\xff\xfe", r"\xff\xfe"}
        assert_no_command_error(vibe_app)
@pytest.mark.asyncio
async def test_ui_handles_utf8_output(vibe_app: VibeApp) -> None:
    """Plain UTF-8 stdout from a shell command is shown verbatim."""
    async with vibe_app.run_test() as pilot:
        input_box = vibe_app.query_one(ChatInputContainer)
        input_box.value = "!echo hello"
        await pilot.press("enter")
        bash_message = await _wait_for_bash_output_message(vibe_app, pilot)
        rendered = str(bash_message.query_one(".bash-output", Static).render())
        assert rendered == "hello\n"
        assert_no_command_error(vibe_app)
@pytest.mark.asyncio
async def test_ui_handles_non_utf8_stderr(vibe_app: VibeApp) -> None:
    # Invalid UTF-8 on stderr must not crash rendering; bytes are replaced.
    async with vibe_app.run_test() as pilot:
        chat_input = vibe_app.query_one(ChatInputContainer)
        chat_input.value = "!bash -lc \"printf '\\\\xff\\\\xfe' 1>&2\""
        await pilot.press("enter")
        message = await _wait_for_bash_output_message(vibe_app, pilot)
        output_widget = message.query_one(".bash-output", Static)
        # NOTE(review): expected text appears to be two U+FFFD replacement chars.
        assert str(output_widget.render()) == "<EFBFBD><EFBFBD>"
        assert_no_command_error(vibe_app)

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
from collections.abc import Callable
import httpx
import pytest
from vibe.cli.update_notifier.github_version_update_gateway import (
GitHubVersionUpdateGateway,
)
from vibe.cli.update_notifier.version_update_gateway import (
VersionUpdateGatewayCause,
VersionUpdateGatewayError,
)
Handler = Callable[[httpx.Request], httpx.Response]
GITHUB_API_URL = "https://api.github.com"
def _raise_connect_timeout(request: httpx.Request) -> httpx.Response:
    # MockTransport handler that simulates a network failure instead of responding.
    raise httpx.ConnectTimeout("boom", request=request)
@pytest.mark.asyncio
async def test_retrieves_latest_version_when_available() -> None:
    # The gateway sends the token as a Bearer header and strips the "v" prefix.
    def handler(request: httpx.Request) -> httpx.Response:
        assert request.headers.get("Authorization") == "Bearer token"
        return httpx.Response(
            status_code=httpx.codes.OK,
            json=[{"tag_name": "v1.2.3", "prerelease": False, "draft": False}],
        )

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway(
            "owner", "repo", token="token", client=client
        )
        update = await notifier.fetch_update()
    assert update is not None
    assert update.latest_version == "1.2.3"
@pytest.mark.asyncio
async def test_strips_uppercase_prefix_from_tag_name() -> None:
    # The tag prefix is stripped case-insensitively ("V0.9.0" -> "0.9.0").
    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(
            status_code=httpx.codes.OK,
            json=[{"tag_name": "V0.9.0", "prerelease": False, "draft": False}],
        )

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        update = await notifier.fetch_update()
    assert update is not None
    assert update.latest_version == "0.9.0"
@pytest.mark.asyncio
async def test_considers_no_update_available_when_no_releases_are_found() -> None:
    """If the repository cannot be accessed (e.g. invalid token), the response will be 404.
    But using API 'releases/latest', if no release has been created, the response will ALSO be 404.
    This test ensures that we consider no update available when no releases are found.
    (And this is why we are using "releases" with a per_page=1 parameter, instead of "releases/latest")
    """
    def handler(request: httpx.Request) -> httpx.Response:
        # An empty list is a *successful* response meaning "no releases yet".
        return httpx.Response(status_code=httpx.codes.OK, json=[])

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        update = await notifier.fetch_update()
    assert update is None
@pytest.mark.asyncio
async def test_considers_no_update_available_when_only_drafts_and_prereleases_are_found() -> (
    None
):
    # Drafts and prereleases are never offered as updates.
    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(
            status_code=httpx.codes.OK,
            json=[
                {"tag_name": "v2.0.0-beta", "prerelease": True, "draft": False},
                {"tag_name": "v2.0.0", "prerelease": False, "draft": True},
            ],
        )

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        update = await notifier.fetch_update()
    assert update is None
@pytest.mark.asyncio
async def test_picks_the_most_recently_published_non_prerelease_and_non_draft() -> None:
    """Drafts and prereleases are skipped; among the rest the most recently
    *published* release wins — even when an older semver was published later.

    Fix: the fixture timestamps used an invalid hour ("T112:00:00Z"), which is
    not valid ISO-8601; corrected to "T12:00:00Z" (same relative ordering).
    """
    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(
            status_code=httpx.codes.OK,
            json=[
                {
                    "tag_name": "v2.0.0-beta",
                    "prerelease": True,
                    "draft": False,
                    "published_at": "2025-10-25T12:00:00Z",
                },
                {
                    "tag_name": "v2.0.0",
                    "prerelease": False,
                    "draft": True,
                    "published_at": "2025-10-26T12:00:00Z",
                },
                {
                    "tag_name": "v1.12.455",
                    "prerelease": False,
                    "draft": False,
                    "published_at": "2025-11-02T12:00:00Z",
                },
                {
                    "tag_name": "1.12.400",
                    "prerelease": False,
                    "draft": False,
                    "published_at": "2025-11-10T12:00:00Z",
                },
                {
                    "tag_name": "1.12.300",
                    "prerelease": False,
                    "draft": False,
                    "published_at": "2025-11-11T12:00:00Z",
                },
            ],
        )

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        update = await notifier.fetch_update()
    assert update is not None
    # 1.12.300 has the latest published_at, despite being the lowest version.
    assert update.latest_version == "1.12.300"
@pytest.mark.parametrize(
    "payload",
    [
        [{"tag_name": "v2.0.0-beta", "prerelease": True, "draft": False}],
        [{"tag_name": "v2.0.0", "prerelease": False, "draft": True}],
    ],
)
@pytest.mark.asyncio
async def test_ignores_draft_releases_and_prereleases(
    payload: list[dict[str, object]],  # each case is a one-release listing
) -> None:
    # A listing containing only ineligible releases yields no update.
    def handler(request: httpx.Request) -> httpx.Response:
        return httpx.Response(status_code=httpx.codes.OK, json=payload)

    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        update = await notifier.fetch_update()
    assert update is None
@pytest.mark.parametrize(
    ("handler", "expected_cause", "expected_custom_message"),
    [
        # 404 -> NOT_FOUND, with a hint about the GITHUB_TOKEN env var.
        (
            lambda _: httpx.Response(status_code=httpx.codes.NOT_FOUND),
            VersionUpdateGatewayCause.NOT_FOUND,
            "Unable to fetch the GitHub releases. Did you export a GITHUB_TOKEN environment variable?",
        ),
        # 403 with an exhausted rate-limit header -> TOO_MANY_REQUESTS.
        (
            lambda _: httpx.Response(
                status_code=httpx.codes.FORBIDDEN,
                headers={"X-RateLimit-Remaining": "0"},
            ),
            VersionUpdateGatewayCause.TOO_MANY_REQUESTS,
            None,
        ),
        (
            lambda _: httpx.Response(status_code=httpx.codes.TOO_MANY_REQUESTS),
            VersionUpdateGatewayCause.TOO_MANY_REQUESTS,
            None,
        ),
        # Plain 403 (no rate-limit header) -> FORBIDDEN.
        (
            lambda _: httpx.Response(status_code=httpx.codes.FORBIDDEN),
            VersionUpdateGatewayCause.FORBIDDEN,
            None,
        ),
        (
            lambda _: httpx.Response(status_code=httpx.codes.INTERNAL_SERVER_ERROR),
            VersionUpdateGatewayCause.ERROR_RESPONSE,
            None,
        ),
        # 200 with a non-JSON body -> INVALID_RESPONSE.
        (
            lambda _: httpx.Response(status_code=httpx.codes.OK, text="not json"),
            VersionUpdateGatewayCause.INVALID_RESPONSE,
            None,
        ),
        (_raise_connect_timeout, VersionUpdateGatewayCause.REQUEST_FAILED, None),
    ],
    ids=[
        "not_found",
        "rate_limit_header",
        "rate_limit_status",
        "forbidden",
        "error_response",
        "invalid_json",
        "request_error",
    ],
)
@pytest.mark.asyncio
async def test_retrieves_nothing_when_fetching_update_fails(
    handler: Handler,
    expected_cause: VersionUpdateGatewayCause,
    expected_custom_message: str | None,
) -> None:
    """Each failure mode raises VersionUpdateGatewayError with the right cause."""
    transport = httpx.MockTransport(handler)
    async with httpx.AsyncClient(
        transport=transport, base_url=GITHUB_API_URL
    ) as client:
        notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
        with pytest.raises(VersionUpdateGatewayError) as excinfo:
            await notifier.fetch_update()
    assert excinfo.value.cause == expected_cause
    if expected_custom_message is not None:
        assert str(excinfo.value) == expected_custom_message

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
import asyncio
from typing import Protocol
import pytest
from textual.app import Notification
from vibe.cli.textual_ui.app import VibeApp
from vibe.cli.update_notifier.fake_version_update_gateway import (
FakeVersionUpdateGateway,
)
from vibe.cli.update_notifier.version_update_gateway import (
VersionUpdate,
VersionUpdateGatewayCause,
VersionUpdateGatewayError,
)
from vibe.core.config import SessionLoggingConfig, VibeConfig
async def _wait_for_notification(
    app: VibeApp, pilot, *, timeout: float = 1.0, interval: float = 0.05
) -> Notification:
    """Poll until the app shows at least one toast; return the newest one."""
    clock = asyncio.get_running_loop().time
    give_up_at = clock() + timeout
    while clock() < give_up_at:
        pending = list(app._notifications)
        if pending:
            return pending[-1]
        await pilot.pause(interval)
    pytest.fail("Notification not displayed")
async def _assert_no_notifications(
    app: VibeApp, pilot, *, timeout: float = 1.0, interval: float = 0.05
) -> None:
    """Poll for `timeout` seconds, failing the moment any toast appears."""
    clock = asyncio.get_running_loop().time
    give_up_at = clock() + timeout
    while clock() < give_up_at:
        if app._notifications:
            pytest.fail("Notification unexpectedly displayed")
        await pilot.pause(interval)
    assert not app._notifications
@pytest.fixture
def vibe_config_with_update_checks_enabled() -> VibeConfig:
    # Session logging off for tests; update checks explicitly on.
    return VibeConfig(
        session_logging=SessionLoggingConfig(enabled=False), enable_update_checks=True
    )
class VibeAppFactory(Protocol):
    """Callable signature of the `make_vibe_app` fixture in this module."""
    def __call__(
        self,
        *,
        notifier: FakeVersionUpdateGateway,
        config: VibeConfig | None = None,
        auto_approve: bool = False,
        current_version: str = "0.1.0",
    ) -> VibeApp: ...
@pytest.fixture
def make_vibe_app(vibe_config_with_update_checks_enabled: VibeConfig) -> VibeAppFactory:
    # Factory fixture: builds a VibeApp wired to a fake update gateway,
    # defaulting to the update-checks-enabled config.
    def _make_app(
        *,
        notifier: FakeVersionUpdateGateway,
        config: VibeConfig | None = None,
        auto_approve: bool = False,
        current_version: str = "0.1.0",
    ) -> VibeApp:
        return VibeApp(
            config=config or vibe_config_with_update_checks_enabled,
            auto_approve=auto_approve,
            version_update_notifier=notifier,
            current_version=current_version,
        )
    return _make_app
@pytest.mark.asyncio
async def test_ui_displays_update_notification(make_vibe_app: VibeAppFactory) -> None:
    # A newer available version produces an info toast with an upgrade hint.
    notifier = FakeVersionUpdateGateway(update=VersionUpdate(latest_version="0.2.0"))
    app = make_vibe_app(notifier=notifier)
    async with app.run_test() as pilot:
        notification = await _wait_for_notification(app, pilot, timeout=0.3)
        assert notification.severity == "information"
        assert notification.title == "Update available"
        assert (
            notification.message
            == '0.1.0 => 0.2.0\nRun "uv tool upgrade mistral-vibe" to update'
        )
@pytest.mark.asyncio
async def test_ui_does_not_display_update_notification_when_not_available(
    make_vibe_app: VibeAppFactory,
) -> None:
    # The gateway is consulted once, reports no update, and no toast appears.
    notifier = FakeVersionUpdateGateway(update=None)
    app = make_vibe_app(notifier=notifier)
    async with app.run_test() as pilot:
        await _assert_no_notifications(app, pilot, timeout=0.3)
    assert notifier.fetch_update_calls == 1
@pytest.mark.asyncio
async def test_ui_displays_warning_toast_when_check_fails(
    make_vibe_app: VibeAppFactory,
) -> None:
    """A gateway failure surfaces as a warning toast naming the cause."""
    notifier = FakeVersionUpdateGateway(
        error=VersionUpdateGatewayError(cause=VersionUpdateGatewayCause.FORBIDDEN)
    )
    app = make_vibe_app(notifier=notifier)
    async with app.run_test() as pilot:
        # Poll via the shared helper instead of a fixed pause, so the test is
        # consistent with its siblings and not sensitive to scheduling jitter.
        warning = await _wait_for_notification(app, pilot, timeout=0.3)
        assert warning.severity == "warning"
        assert "forbidden" in warning.message.lower()
@pytest.mark.asyncio
async def test_ui_does_not_invoke_gateway_nor_show_error_notification_when_update_checks_are_disabled(
    vibe_config_with_update_checks_enabled: VibeConfig, make_vibe_app: VibeAppFactory
) -> None:
    """With update checks disabled, a failing gateway is never called — no error toast.

    Fix: the body was a copy-paste of the update-available sibling test and
    configured the fake with an available update, so the "error notification"
    path in the test's name was never exercised. Configure the fake to raise
    instead, matching the warning-toast test above.
    """
    config = vibe_config_with_update_checks_enabled
    config.enable_update_checks = False
    notifier = FakeVersionUpdateGateway(
        error=VersionUpdateGatewayError(cause=VersionUpdateGatewayCause.FORBIDDEN)
    )
    app = make_vibe_app(notifier=notifier, config=config)
    async with app.run_test() as pilot:
        await _assert_no_notifications(app, pilot, timeout=0.3)
    assert notifier.fetch_update_calls == 0
@pytest.mark.asyncio
async def test_ui_does_not_invoke_gateway_nor_show_update_notification_when_update_checks_are_disabled(
    vibe_config_with_update_checks_enabled: VibeConfig, make_vibe_app: VibeAppFactory
) -> None:
    # Even with an update available, disabled checks mean no gateway call, no toast.
    config = vibe_config_with_update_checks_enabled
    config.enable_update_checks = False
    notifier = FakeVersionUpdateGateway(update=VersionUpdate(latest_version="0.2.0"))
    app = make_vibe_app(notifier=notifier, config=config)
    async with app.run_test() as pilot:
        await _assert_no_notifications(app, pilot, timeout=0.3)
    assert notifier.fetch_update_calls == 0

View File

@@ -0,0 +1,146 @@
from __future__ import annotations
import pytest
from vibe.cli.update_notifier.fake_version_update_gateway import (
FakeVersionUpdateGateway,
)
from vibe.cli.update_notifier.version_update import (
VersionUpdateError,
is_version_update_available,
)
from vibe.cli.update_notifier.version_update_gateway import (
VersionUpdate,
VersionUpdateGatewayCause,
VersionUpdateGatewayError,
)
@pytest.mark.asyncio
async def test_retrieves_the_latest_version_update_when_available() -> None:
latest_update = "1.0.3"
version_update_notifier = FakeVersionUpdateGateway(
update=VersionUpdate(latest_version=latest_update)
)
update = await is_version_update_available(
version_update_notifier, current_version="1.0.0"
)
assert update is not None
assert update.latest_version == latest_update
@pytest.mark.asyncio
async def test_retrieves_nothing_when_the_current_version_is_the_latest() -> None:
current_version = "1.0.0"
latest_version = "1.0.0"
version_update_notifier = FakeVersionUpdateGateway(
update=VersionUpdate(latest_version=latest_version)
)
update = await is_version_update_available(
version_update_notifier, current_version=current_version
)
assert update is None
@pytest.mark.asyncio
async def test_retrieves_nothing_when_the_current_version_is_greater_than_the_latest() -> (
None
):
current_version = "0.2.0"
latest_version = "0.1.2"
version_update_notifier = FakeVersionUpdateGateway(
update=VersionUpdate(latest_version=latest_version)
)
update = await is_version_update_available(
version_update_notifier, current_version=current_version
)
assert update is None
@pytest.mark.asyncio
async def test_retrieves_nothing_when_no_version_is_available() -> None:
version_update_notifier = FakeVersionUpdateGateway(update=None)
update = await is_version_update_available(
version_update_notifier, current_version="1.0.0"
)
assert update is None
@pytest.mark.asyncio
async def test_retrieves_nothing_when_latest_version_is_invalid() -> None:
version_update_notifier = FakeVersionUpdateGateway(
update=VersionUpdate(latest_version="invalid-version")
)
update = await is_version_update_available(
version_update_notifier, current_version="1.0.0"
)
assert update is None
@pytest.mark.asyncio
async def test_replaces_hyphens_with_plus_signs_in_latest_version_to_conform_with_PEP_440() -> (
None
):
version_update_notifier = FakeVersionUpdateGateway(
# if we were not replacing hyphens with plus signs, this should fail for PEP 440
update=VersionUpdate(latest_version="1.6.1-jetbrains")
)
update = await is_version_update_available(
version_update_notifier, current_version="1.0.0"
)
assert update is not None
assert update.latest_version == "1.6.1-jetbrains"
@pytest.mark.asyncio
async def test_retrieves_nothing_when_current_version_is_invalid() -> None:
version_update_notifier = FakeVersionUpdateGateway(
update=VersionUpdate(latest_version="1.0.1")
)
update = await is_version_update_available(
version_update_notifier, current_version="invalid-version"
)
assert update is None
@pytest.mark.parametrize(
("cause", "expected_message_substring"),
[
(VersionUpdateGatewayCause.TOO_MANY_REQUESTS, "Rate limit exceeded"),
(VersionUpdateGatewayCause.INVALID_RESPONSE, "invalid response"),
(
VersionUpdateGatewayCause.NOT_FOUND,
"Unable to fetch the releases. Please check your permissions.",
),
(VersionUpdateGatewayCause.ERROR_RESPONSE, "Unexpected response"),
(VersionUpdateGatewayCause.REQUEST_FAILED, "Network error"),
],
)
@pytest.mark.asyncio
async def test_raises_version_update_error(
cause: VersionUpdateGatewayCause, expected_message_substring: str
) -> None:
version_update_notifier = FakeVersionUpdateGateway(
error=VersionUpdateGatewayError(cause=cause)
)
with pytest.raises(VersionUpdateError) as excinfo:
await is_version_update_available(
version_update_notifier, current_version="1.0.0"
)
assert expected_message_substring in str(excinfo.value)

1715
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff

45
vibe-acp.spec Normal file
View File

@@ -0,0 +1,45 @@
# -*- mode: python ; coding: utf-8 -*-
# PyInstaller spec for the standalone `vibe-acp` binary (one-file build:
# runtime_tmpdir=None and the data/binaries are passed straight to EXE).
a = Analysis(
    ['vibe/acp/entrypoint.py'],
    pathex=[],
    binaries=[],
    datas=[
        # By default, pyinstaller doesn't include the .md files
        ('vibe/core/prompts/*.md', 'vibe/core/prompts'),
        ('vibe/core/tools/builtins/prompts/*.md', 'vibe/core/tools/builtins/prompts'),
        # This is necessary because tools are dynamically called in vibe, meaning there is no static reference to those files
        ('vibe/core/tools/builtins/*.py', 'vibe/core/tools/builtins'),
        ('vibe/acp/tools/builtins/*.py', 'vibe/acp/tools/builtins'),
    ],
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)
exe = EXE(
    pyz,
    a.scripts,
    a.binaries,
    a.datas,
    [],
    name='vibe-acp',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,  # compress the executable when UPX is available
    upx_exclude=[],
    runtime_tmpdir=None,
    console=True,  # ACP communicates over stdio, so keep a console process
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
)

5
vibe/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from __future__ import annotations
from pathlib import Path

# Absolute path of the installed `vibe` package directory; used elsewhere to
# locate bundled resources (e.g. the builtin ACP tools in acp_agent.newSession).
VIBE_ROOT = Path(__file__).parent

0
vibe/acp/__init__.py Normal file
View File

441
vibe/acp/acp_agent.py Normal file
View File

@@ -0,0 +1,441 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncGenerator
from pathlib import Path
import sys
from typing import Any, cast, override
from acp import (
PROTOCOL_VERSION,
Agent as AcpAgent,
AgentSideConnection,
AuthenticateRequest,
CancelNotification,
InitializeRequest,
InitializeResponse,
LoadSessionRequest,
NewSessionRequest,
NewSessionResponse,
PromptRequest,
PromptResponse,
RequestError,
RequestPermissionRequest,
SessionNotification,
SetSessionModelRequest,
SetSessionModelResponse,
SetSessionModeRequest,
SetSessionModeResponse,
stdio_streams,
)
from acp.helpers import ContentBlock, SessionUpdate
from acp.schema import (
AgentCapabilities,
AgentMessageChunk,
AllowedOutcome,
AuthenticateResponse,
AuthMethod,
Implementation,
ModelInfo,
PromptCapabilities,
SessionModelState,
SessionModeState,
TextContentBlock,
TextResourceContents,
ToolCall,
)
from pydantic import BaseModel, ConfigDict
from vibe import VIBE_ROOT
from vibe.acp.tools.base import BaseAcpTool
from vibe.acp.tools.session_update import (
tool_call_session_update,
tool_result_session_update,
)
from vibe.acp.utils import TOOL_OPTIONS, ToolOption, VibeSessionMode
from vibe.core import __version__
from vibe.core.agent import Agent as VibeAgent
from vibe.core.autocompletion.path_prompt_adapter import render_path_prompt
from vibe.core.config import MissingAPIKeyError, VibeConfig, load_api_keys_from_env
from vibe.core.types import (
AssistantEvent,
AsyncApprovalCallback,
ToolCallEvent,
ToolResultEvent,
)
from vibe.core.utils import CancellationReason, get_user_cancellation_message
class AcpSession(BaseModel):
    """Per-session state tying an ACP session id to its Vibe agent."""
    # Needed because VibeAgent / asyncio.Task are not pydantic types.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    id: str  # pinned to agent.session_id at creation time (see newSession)
    agent: VibeAgent
    mode_id: VibeSessionMode = VibeSessionMode.APPROVAL_REQUIRED
    task: asyncio.Task[None] | None = None  # in-flight prompt task, if any
class VibeAcpAgent(AcpAgent):
def __init__(self, connection: AgentSideConnection) -> None:
    # One VibeAcpAgent serves many sessions over a single ACP connection.
    self.sessions: dict[str, AcpSession] = {}
    self.connection = connection
    self.client_capabilities = None  # populated during initialize()
@override
async def initialize(self, params: InitializeRequest) -> InitializeResponse:
    """Negotiate capabilities and advertise the terminal-based auth method."""
    self.client_capabilities = params.clientCapabilities
    # The ACP Agent process can be launched in 3 different ways, depending on installation
    # - dev mode: `uv run vibe-acp`, ran from the project root
    # - uv tool install: `vibe-acp`, similar to dev mode, but uv takes care of path resolution
    # - bundled binary: `./vibe-acp` from binary location
    # The 2 first modes are working similarly, under the hood uv runs `/some/python /my/entrypoint.py`
    # The last mode is quite different as our bundler also includes the python install.
    # So sys.executable is already /path/to/binary/vibe-acp.
    # For this reason, we make a distinction in the way we call the setup command
    command = sys.executable
    if "python" not in Path(command).name:
        # It's the case for bundled binaries, we don't need any other arguments
        args = ["--setup"]
    else:
        # Interpreter-based launch: re-run the entrypoint script with --setup.
        script_name = sys.argv[0]
        args = [script_name, "--setup"]
    auth_methods = [
        AuthMethod(
            id="vibe-setup",
            name="Register your API Key",
            description="Register your API Key inside Mistral Vibe",
            field_meta={
                "terminal-auth": {
                    "command": command,
                    "args": args,
                    "label": "Mistral Vibe Setup",
                }
            },
        )
    ]
    response = InitializeResponse(
        agentCapabilities=AgentCapabilities(
            # loadSession=False: session resumption is not implemented yet.
            loadSession=False,
            promptCapabilities=PromptCapabilities(
                audio=False, embeddedContext=True, image=False
            ),
        ),
        protocolVersion=PROTOCOL_VERSION,
        agentInfo=Implementation(
            name="@mistralai/mistral-vibe",
            title="Mistral Vibe",
            version=__version__,
        ),
        authMethods=auth_methods,
    )
    return response
@override
async def authenticate(
    self, params: AuthenticateRequest
) -> AuthenticateResponse | None:
    # Auth is handled out-of-band via the "terminal-auth" setup command
    # advertised in initialize(); direct ACP authentication is unsupported.
    raise NotImplementedError("Not implemented yet")
@override
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
    """Create a Vibe agent for a new ACP session and report models/modes.

    Raises an auth-required RequestError when no API key is configured.
    """
    capability_disabled_tools = self._get_disabled_tools_from_capabilities()
    load_api_keys_from_env()
    try:
        config = VibeConfig.load(
            workdir=Path(params.cwd),
            tool_paths=[str(VIBE_ROOT / "acp" / "tools" / "builtins")],
            disabled_tools=capability_disabled_tools,
        )
    except MissingAPIKeyError as e:
        # Without an API key the client must run the auth flow first.
        raise RequestError.auth_required({
            "message": "You must be authenticated before creating a new session"
        }) from e
    agent = VibeAgent(config=config, auto_approve=False, enable_streaming=True)
    # NOTE: For now, we pin session.id to agent.session_id right after init time.
    # We should just use agent.session_id everywhere, but it can still change during
    # session lifetime (e.g. agent.compact is called).
    # We should refactor agent.session_id to make it immutable in ACP context.
    session = AcpSession(id=agent.session_id, agent=agent)
    self.sessions[session.id] = session
    if not agent.auto_approve:
        # Route tool-permission prompts back to the ACP client.
        agent.set_approval_callback(
            self._create_approval_callback(agent.session_id)
        )
    response = NewSessionResponse(
        sessionId=agent.session_id,
        models=SessionModelState(
            currentModelId=agent.config.active_model,
            availableModels=[
                ModelInfo(modelId=model.alias, name=model.alias)
                for model in agent.config.models
            ],
        ),
        modes=SessionModeState(
            currentModeId=session.mode_id,
            availableModes=VibeSessionMode.get_all_acp_session_modes(),
        ),
    )
    return response
def _get_disabled_tools_from_capabilities(self) -> list[str]:
    """Map missing client capabilities to the Vibe tool names requiring them."""
    caps = self.client_capabilities
    if not caps:
        # Capabilities unknown (initialize not called yet): disable nothing.
        return []
    disabled: list[str] = []
    if not caps.terminal:
        disabled.append("bash")
    fs = caps.fs
    if fs:
        if not fs.readTextFile:
            disabled.append("read_file")
        if not fs.writeTextFile:
            # Both writing tools depend on the client's write capability.
            disabled.append("write_file")
            disabled.append("search_replace")
    return disabled
def _create_approval_callback(self, session_id: str) -> AsyncApprovalCallback:
    """Build a callback that routes tool-approval prompts through ACP.

    The returned coroutine asks the client for permission and translates the
    outcome into Vibe's (choice, optional-feedback) convention.
    """
    async def approval_callback(
        tool_name: str, args: dict[str, Any], tool_call_id: str
    ) -> tuple[str, str | None]:
        # Create the tool call update
        tool_call = ToolCall(toolCallId=tool_call_id)
        # Request permission from the user
        request = RequestPermissionRequest(
            sessionId=session_id, toolCall=tool_call, options=TOOL_OPTIONS
        )
        response = await self.connection.requestPermission(request)
        # The outcome tag distinguishes an explicit selection from a cancel.
        if response.outcome.outcome == "selected":
            outcome = cast(AllowedOutcome, response.outcome)
            return self._handle_permission_selection(outcome.optionId)
        else:
            # Cancelled/dismissed dialog: decline and tell the model why.
            return (
                "n",
                str(
                    get_user_cancellation_message(
                        CancellationReason.OPERATION_CANCELLED
                    )
                ),
            )
    return approval_callback
@staticmethod
def _handle_permission_selection(option_id: str) -> tuple[str, str | None]:
    """Translate a permission-dialog option id into a (choice, feedback) pair."""
    if option_id == ToolOption.ALLOW_ONCE:
        return ("y", None)
    if option_id == ToolOption.ALLOW_ALWAYS:
        return ("a", None)
    if option_id == ToolOption.REJECT_ONCE:
        return ("n", "User rejected the tool call, provide an alternative plan")
    # Defensive default: treat anything unrecognized as a rejection.
    return ("n", f"Unknown option: {option_id}")
def _get_session(self, session_id: str) -> AcpSession:
    """Look up a live session by id; raise invalid-params when unknown."""
    session = self.sessions.get(session_id)
    if session is None:
        raise RequestError.invalid_params({"session": "Not found"})
    return session
@override
async def loadSession(self, params: LoadSessionRequest) -> None:
    # Session resumption is unsupported; initialize() also advertises
    # loadSession=False, so well-behaved clients never call this.
    raise NotImplementedError()
@override
async def setSessionMode(
    self, params: SetSessionModeRequest
) -> SetSessionModeResponse | None:
    """Switch a session's approval mode; returns None for unknown mode ids."""
    session = self._get_session(params.sessionId)
    if not VibeSessionMode.is_valid(params.modeId):
        return None
    session.mode_id = VibeSessionMode(params.modeId)
    # AUTO_APPROVE bypasses the per-tool permission prompts.
    session.agent.auto_approve = params.modeId == VibeSessionMode.AUTO_APPROVE
    return SetSessionModeResponse()
@override
async def setSessionModel(
    self, params: SetSessionModelRequest
) -> SetSessionModelResponse | None:
    """Persist a new active model and hot-reload the session's agent with it.

    Returns None when the model id does not match any configured model alias.
    """
    session = self._get_session(params.sessionId)
    model_aliases = [model.alias for model in session.agent.config.models]
    if params.modelId not in model_aliases:
        return None
    # Persist the choice first so the reloaded config picks it up.
    VibeConfig.save_updates({"active_model": params.modelId})
    new_config = VibeConfig.load(
        workdir=session.agent.config.workdir,
        tool_paths=session.agent.config.tool_paths,
        disabled_tools=self._get_disabled_tools_from_capabilities(),
    )
    await session.agent.reload_with_initial_messages(config=new_config)
    return SetSessionModelResponse()
@override
async def prompt(self, params: PromptRequest) -> PromptResponse:
    """Run one agent turn on the prompt, streaming updates to the client.

    Returns a PromptResponse whose stopReason is "end_turn" on success,
    "cancelled" when cancel() interrupted the turn, or "refusal" on error.

    Raises:
        RuntimeError: if a prompt is already running for this session.
    """
    session = self._get_session(params.sessionId)
    if session.task is not None:
        raise RuntimeError(
            "Concurrent prompts are not supported yet, wait for agent to finish"
        )
    text_prompt = self._build_text_prompt(params.prompt)

    async def agent_task() -> None:
        # Forward every agent update to the client as a session notification.
        async for update in self._run_agent_loop(session, text_prompt):
            await self.connection.sessionUpdate(
                SessionNotification(sessionId=session.id, update=update)
            )

    try:
        # The task handle is stored on the session so cancel() can reach it.
        session.task = asyncio.create_task(agent_task())
        await session.task
    except asyncio.CancelledError:
        return PromptResponse(stopReason="cancelled")
    except Exception as e:
        # Surface the failure to the client as a message chunk, then refuse.
        await self.connection.sessionUpdate(
            SessionNotification(
                sessionId=params.sessionId,
                update=AgentMessageChunk(
                    sessionUpdate="agent_message_chunk",
                    content=TextContentBlock(type="text", text=f"Error: {e!s}"),
                ),
            )
        )
        return PromptResponse(stopReason="refusal")
    finally:
        # Always clear the slot so the next prompt can run.
        session.task = None
    return PromptResponse(stopReason="end_turn")
def _build_text_prompt(self, acp_prompt: list[ContentBlock]) -> str:
text_prompt = ""
for block in acp_prompt:
separator = "\n\n" if text_prompt else ""
match block.type:
# NOTE: ACP supports annotations, but we don't use them here yet.
case "text":
text_prompt = f"{text_prompt}{separator}{block.text}"
case "resource":
block_content = (
block.resource.text
if isinstance(block.resource, TextResourceContents)
else block.resource.blob
)
fields = {"path": block.resource.uri, "content": block_content}
parts = [
f"{k}: {v}"
for k, v in fields.items()
if v is not None and (v or isinstance(v, (int, float)))
]
block_prompt = "\n".join(parts)
text_prompt = f"{text_prompt}{separator}{block_prompt}"
case "resource_link":
# NOTE: we currently keep more information than just the URI
# making it more detailed than the output of the read_file tool.
# This is OK, but might be worth testing how it affect performance.
fields = {
"uri": block.uri,
"name": block.name,
"title": block.title,
"description": block.description,
"mimeType": block.mimeType,
"size": block.size,
}
parts = [
f"{k}: {v}"
for k, v in fields.items()
if v is not None and (v or isinstance(v, (int, float)))
]
block_prompt = "\n".join(parts)
text_prompt = f"{text_prompt}{separator}{block_prompt}"
case _:
raise ValueError(f"Unsupported content block type: {block.type}")
return text_prompt
async def _run_agent_loop(
    self, session: AcpSession, prompt: str
) -> AsyncGenerator[SessionUpdate]:
    """Drive the agent on the prompt and translate its events to ACP updates."""
    rendered_prompt = render_path_prompt(
        prompt, base_dir=session.agent.config.effective_workdir
    )
    async for event in session.agent.act(rendered_prompt):
        if isinstance(event, AssistantEvent):
            yield AgentMessageChunk(
                sessionUpdate="agent_message_chunk",
                content=TextContentBlock(type="text", text=event.content),
            )
        elif isinstance(event, ToolCallEvent):
            # ACP-aware tools get the connection/session context injected
            # before the call so they can stream their own progress updates.
            if issubclass(event.tool_class, BaseAcpTool):
                event.tool_class.update_tool_state(
                    tool_manager=session.agent.tool_manager,
                    connection=self.connection,
                    session_id=session.id,
                    tool_call_id=event.tool_call_id,
                )
            session_update = tool_call_session_update(event)
            if session_update:
                yield session_update
        elif isinstance(event, ToolResultEvent):
            session_update = tool_result_session_update(event)
            if session_update:
                yield session_update
@override
async def cancel(self, params: CancelNotification) -> None:
    """Cancel the session's running prompt task, if one is still active."""
    session = self._get_session(params.sessionId)
    task = session.task
    if task is not None and not task.done():
        task.cancel()
        # Clear the slot immediately; prompt() also clears it in its finally.
        session.task = None
@override
async def extMethod(self, method: str, params: dict) -> dict:
    # Protocol extension methods are not supported.
    raise NotImplementedError()
@override
async def extNotification(self, method: str, params: dict) -> None:
    # Protocol extension notifications are not supported.
    raise NotImplementedError()
async def _run_acp_server() -> None:
    """Wire a VibeAcpAgent to stdio and serve until the process is stopped."""
    reader, writer = await stdio_streams()
    AgentSideConnection(lambda connection: VibeAcpAgent(connection), writer, reader)
    # Block forever; the connection created above drives all the work.
    await asyncio.Event().wait()
def run_acp_server() -> None:
    """Synchronous entry point: run the ACP stdio server until interrupted."""
    try:
        asyncio.run(_run_acp_server())
    except KeyboardInterrupt:
        # Normal shutdown path when the server process is terminated.
        return
    except Exception as e:
        # Report unexpected failures on stderr, then propagate.
        print(f"ACP Agent Server error: {e}", file=sys.stderr)
        raise

37
vibe/acp/entrypoint.py Normal file
View File

@@ -0,0 +1,37 @@
from __future__ import annotations
import argparse
from dataclasses import dataclass
import sys
from vibe.acp.acp_agent import run_acp_server
from vibe.setup.onboarding import run_onboarding
# Configure line buffering for subprocess communication
# (reconfigure() exists on the real TextIOWrapper stdio streams; pyright's
# stubs type them as TextIO, hence the ignores).
sys.stdout.reconfigure(line_buffering=True)  # pyright: ignore[reportAttributeAccessIssue]
sys.stderr.reconfigure(line_buffering=True)  # pyright: ignore[reportAttributeAccessIssue]
sys.stdin.reconfigure(line_buffering=True)  # pyright: ignore[reportAttributeAccessIssue]
@dataclass
class Arguments:
    """Parsed command-line arguments for the ACP entrypoint."""

    # True when the user asked to run API-key onboarding instead of the server.
    setup: bool
def parse_arguments() -> Arguments:
    """Parse the CLI flags for the ACP entrypoint into an Arguments record."""
    parser = argparse.ArgumentParser(description="Run Mistral Vibe in ACP mode")
    parser.add_argument("--setup", action="store_true", help="Setup API key and exit")
    namespace = parser.parse_args()
    return Arguments(setup=namespace.setup)
def main() -> None:
    """Entrypoint: run onboarding when --setup is given, otherwise serve ACP."""
    args = parse_arguments()
    if args.setup:
        run_onboarding()
        # Onboarding is a one-shot action; exit instead of starting the server.
        sys.exit(0)
    run_acp_server()
# Allow running this module directly (python -m vibe.acp.entrypoint).
if __name__ == "__main__":
    main()

View File

100
vibe/acp/tools/base.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Protocol, cast, runtime_checkable
from acp import AgentSideConnection, SessionNotification
from acp.helpers import SessionUpdate, ToolCallContentVariant
from acp.schema import ToolCallProgress
from pydantic import Field
from vibe.core.tools.base import BaseTool, ToolError
from vibe.core.tools.manager import ToolManager
from vibe.core.types import ToolCallEvent, ToolResultEvent
from vibe.core.utils import logger
@runtime_checkable
class ToolCallSessionUpdateProtocol(Protocol):
    """Tools implementing this can render a ToolCallEvent as an ACP update."""

    @classmethod
    def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None: ...
@runtime_checkable
class ToolResultSessionUpdateProtocol(Protocol):
    """Tools implementing this can render a ToolResultEvent as an ACP update."""

    @classmethod
    def tool_result_session_update(
        cls, event: ToolResultEvent
    ) -> SessionUpdate | None: ...
class AcpToolState:
    """Mixin holding the per-call ACP context injected into a tool's state.

    NOTE(review): this class is not itself a pydantic model; the Field()
    defaults only take effect because concrete states also inherit from a
    pydantic-based tool state (e.g. BaseToolState) — confirm every subclass
    does so.
    """

    # Agent-side connection used to talk back to the ACP client.
    connection: AgentSideConnection | None = Field(
        default=None, description="ACP agent-side connection"
    )
    # Id of the ACP session the current tool call belongs to.
    session_id: str | None = Field(default=None, description="Current ACP session ID")
    # Id of the current tool call, when invoked inside one.
    tool_call_id: str | None = Field(
        default=None, description="Current ACP tool call ID"
    )
class BaseAcpTool[ToolState: AcpToolState](BaseTool):
    """Base class for tools that proxy their work through an ACP client."""

    state: ToolState

    @classmethod
    def get_tool_instance(
        cls, tool_name: str, tool_manager: ToolManager
    ) -> BaseAcpTool[AcpToolState]:
        """Fetch the registered instance of a tool from the manager."""
        return cast(BaseAcpTool[AcpToolState], tool_manager.get(tool_name))

    @classmethod
    def update_tool_state(
        cls,
        *,
        tool_manager: ToolManager,
        connection: AgentSideConnection | None,
        session_id: str | None,
        tool_call_id: str | None,
    ) -> None:
        """Inject the current ACP context into this tool's state before a call."""
        tool_instance = cls.get_tool_instance(cls.get_name(), tool_manager)
        tool_instance.state.connection = connection
        tool_instance.state.session_id = session_id
        tool_instance.state.tool_call_id = tool_call_id

    @classmethod
    @abstractmethod
    def _get_tool_state_class(cls) -> type[ToolState]: ...

    def _load_state(self) -> tuple[AgentSideConnection, str, str | None]:
        """Return (connection, session_id, tool_call_id) from the injected state.

        Raises:
            ToolError: when connection or session id was never injected, i.e.
                the tool is being used outside an ACP session.
        """
        if self.state.connection is None:
            raise ToolError(
                "Connection not available in tool state. This tool can only be used within an ACP session."
            )
        if self.state.session_id is None:
            raise ToolError(
                "Session ID not available in tool state. This tool can only be used within an ACP session."
            )
        return self.state.connection, self.state.session_id, self.state.tool_call_id

    async def _send_in_progress_session_update(
        self, content: list[ToolCallContentVariant] | None = None
    ) -> None:
        """Best-effort "in_progress" tool-call update; failures are only logged."""
        connection, session_id, tool_call_id = self._load_state()
        if tool_call_id is None:
            # No tool-call context was injected: nothing to report against.
            return
        try:
            await connection.sessionUpdate(
                SessionNotification(
                    sessionId=session_id,
                    update=ToolCallProgress(
                        sessionUpdate="tool_call_update",
                        toolCallId=tool_call_id,
                        status="in_progress",
                        content=content,
                    ),
                )
            )
        except Exception as e:
            # A failed progress update must not fail the tool itself.
            logger.error(f"Failed to update session: {e!r}")

View File

@@ -0,0 +1,144 @@
from __future__ import annotations
import asyncio
import shlex
from acp import CreateTerminalRequest, TerminalHandle
from acp.schema import (
EnvVariable,
TerminalToolCallContent,
ToolCallProgress,
ToolCallStart,
WaitForTerminalExitResponse,
)
from vibe import VIBE_ROOT
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
from vibe.core.tools.base import BaseToolState, ToolError
from vibe.core.tools.builtins.bash import Bash as CoreBashTool, BashArgs, BashResult
from vibe.core.types import ToolCallEvent, ToolResultEvent
from vibe.core.utils import logger
class AcpBashState(BaseToolState, AcpToolState):
    """Bash tool state combined with the injected ACP session context."""

    pass
class Bash(CoreBashTool, BaseAcpTool[AcpBashState]):
    """Bash tool that executes commands in a client-side ACP terminal.

    Instead of spawning a local subprocess, the command is sent to the ACP
    client via createTerminal and its output is read back from the terminal.
    """

    prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "bash.md"
    state: AcpBashState

    @classmethod
    def _get_tool_state_class(cls) -> type[AcpBashState]:
        return AcpBashState

    async def run(self, args: BashArgs) -> BashResult:
        """Run the command in a client terminal and return its combined output.

        Raises:
            ToolError: if the terminal cannot be created, or on timeout.
        """
        connection, session_id, _ = self._load_state()
        timeout = args.timeout or self.config.default_timeout
        max_bytes = self.config.max_output_bytes
        env, command, cmd_args = self._parse_command(args.command)
        create_request = CreateTerminalRequest(
            sessionId=session_id,
            command=command,
            args=cmd_args,
            env=env,
            cwd=str(self.config.effective_workdir),
            outputByteLimit=max_bytes,
        )
        try:
            terminal_handle = await connection.createTerminal(create_request)
        except Exception as e:
            raise ToolError(f"Failed to create terminal: {e!r}") from e
        # Let the client render the live terminal inside this tool call.
        await self._send_in_progress_session_update([
            TerminalToolCallContent(type="terminal", terminalId=terminal_handle.id)
        ])
        try:
            exit_response = await self._wait_for_terminal_exit(
                terminal_handle, timeout, args.command
            )
            output_response = await terminal_handle.current_output()
            # NOTE(review): the terminal returns one combined output stream,
            # so stderr is reported empty here.
            return self._build_result(
                command=args.command,
                stdout=output_response.output,
                stderr="",
                returncode=exit_response.exitCode or 0,
            )
        finally:
            # Always release the client-side terminal, even on timeout/error.
            try:
                await terminal_handle.release()
            except Exception as e:
                logger.error(f"Failed to release terminal: {e!r}")

    def _parse_command(
        self, command_str: str
    ) -> tuple[list[EnvVariable], str, list[str]]:
        """Split a shell-like command string into (env assignments, command, args).

        Leading KEY=VALUE tokens become environment variables; the first token
        without "=" (or after the env prefix) is the command, the rest args.
        NOTE(review): shell syntax (pipes, redirects) is not supported here —
        confirm callers only pass simple commands.
        """
        parts = shlex.split(command_str)
        env: list[EnvVariable] = []
        command: str = ""
        args: list[str] = []
        for part in parts:
            if "=" in part and not command:
                key, value = part.split("=", 1)
                env.append(EnvVariable(name=key, value=value))
            elif not command:
                command = part
            else:
                args.append(part)
        return env, command, args

    @classmethod
    def get_summary(cls, args: BashArgs) -> str:
        """One-line human-readable summary of the call, used as the UI title."""
        summary = f"{args.command}"
        if args.timeout:
            summary += f" (timeout {args.timeout}s)"
        return summary

    async def _wait_for_terminal_exit(
        self, terminal_handle: TerminalHandle, timeout: int, command: str
    ) -> WaitForTerminalExitResponse:
        """Wait for the terminal to exit, killing it and raising on timeout."""
        try:
            return await asyncio.wait_for(
                terminal_handle.wait_for_exit(), timeout=timeout
            )
        except TimeoutError:
            # Best-effort kill; the timeout error is raised either way.
            try:
                await terminal_handle.kill()
            except Exception as e:
                logger.error(f"Failed to kill terminal: {e!r}")
            raise self._build_timeout_error(command, timeout)

    @classmethod
    def tool_call_session_update(cls, event: ToolCallEvent) -> ToolCallStart:
        """Render the start of a bash tool call as an ACP "tool_call" update."""
        if not isinstance(event.args, BashArgs):
            raise ValueError(f"Unexpected tool args: {event.args}")
        return ToolCallStart(
            sessionUpdate="tool_call",
            title=Bash.get_summary(event.args),
            content=None,
            toolCallId=event.tool_call_id,
            kind="execute",
            rawInput=event.args.model_dump_json(),
        )

    @classmethod
    def tool_result_session_update(
        cls, event: ToolResultEvent
    ) -> ToolCallProgress | None:
        """Mark the tool call completed or failed once the result is known."""
        return ToolCallProgress(
            sessionUpdate="tool_call_update",
            toolCallId=event.tool_call_id,
            status="failed" if event.error else "completed",
        )

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from pathlib import Path
from acp import ReadTextFileRequest
from vibe import VIBE_ROOT
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.read_file import (
ReadFile as CoreReadFileTool,
ReadFileArgs,
ReadFileResult,
ReadFileState,
_ReadResult,
)
# Re-export so callers can import ReadFileResult from this module.
ReadFileResult = ReadFileResult
class AcpReadFileState(ReadFileState, AcpToolState):
    """Read-file tool state combined with the injected ACP session context."""

    pass
class ReadFile(CoreReadFileTool, BaseAcpTool[AcpReadFileState]):
    """Read-file tool that delegates reads to the ACP client's filesystem."""

    state: AcpReadFileState
    prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "read_file.md"

    @classmethod
    def _get_tool_state_class(cls) -> type[AcpReadFileState]:
        return AcpReadFileState

    async def _read_file(self, args: ReadFileArgs, file_path: Path) -> _ReadResult:
        """Read the file via the client's fs/read_text_file capability.

        Raises:
            ToolError: when the client-side read fails.
        """
        connection, session_id, _ = self._load_state()
        # offset 0 sends no line parameter; assumes ACP `line` is 1-based —
        # TODO(review): confirm against the ACP spec.
        line = args.offset + 1 if args.offset > 0 else None
        limit = args.limit
        read_request = ReadTextFileRequest(
            sessionId=session_id, path=str(file_path), line=line, limit=limit
        )
        await self._send_in_progress_session_update()
        try:
            response = await connection.readTextFile(read_request)
        except Exception as e:
            raise ToolError(f"Error reading {file_path}: {e}") from e
        content_lines = response.content.splitlines(keepends=True)
        lines_read = len(content_lines)
        bytes_read = sum(len(line.encode("utf-8")) for line in content_lines)
        # Heuristic: reading exactly `limit` lines is treated as truncation.
        was_truncated = args.limit is not None and lines_read >= args.limit
        return _ReadResult(
            lines=content_lines, bytes_read=bytes_read, was_truncated=was_truncated
        )

View File

@@ -0,0 +1,132 @@
from __future__ import annotations
from pathlib import Path
from acp import ReadTextFileRequest, WriteTextFileRequest
from acp.helpers import SessionUpdate
from acp.schema import (
FileEditToolCallContent,
ToolCallLocation,
ToolCallProgress,
ToolCallStart,
)
from vibe import VIBE_ROOT
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.search_replace import (
SearchReplace as CoreSearchReplaceTool,
SearchReplaceArgs,
SearchReplaceResult,
SearchReplaceState,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
class AcpSearchReplaceState(SearchReplaceState, AcpToolState):
    """Search/replace tool state plus ACP context and an in-memory backup."""

    # Content of the target file as last read; used to write a .bak copy.
    file_backup_content: str | None = None
class SearchReplace(CoreSearchReplaceTool, BaseAcpTool[AcpSearchReplaceState]):
    """Search/replace tool that reads and writes files through the ACP client."""

    state: AcpSearchReplaceState
    prompt_path = (
        VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "search_replace.md"
    )

    @classmethod
    def _get_tool_state_class(cls) -> type[AcpSearchReplaceState]:
        return AcpSearchReplaceState

    async def _read_file(self, file_path: Path) -> str:
        """Read the target file via the client, keeping its content as backup.

        Raises:
            ToolError: when the client-side read fails.
        """
        connection, session_id, _ = self._load_state()
        read_request = ReadTextFileRequest(sessionId=session_id, path=str(file_path))
        await self._send_in_progress_session_update()
        try:
            response = await connection.readTextFile(read_request)
        except Exception as e:
            raise ToolError(f"Unexpected error reading {file_path}: {e}") from e
        # Remember the original content so _backup_file can write a .bak copy.
        self.state.file_backup_content = response.content
        return response.content

    async def _backup_file(self, file_path: Path) -> None:
        """Write the previously-read content to `<file>.bak` (no-op if none)."""
        if self.state.file_backup_content is None:
            return
        await self._write_file(
            file_path.with_suffix(file_path.suffix + ".bak"),
            self.state.file_backup_content,
        )

    async def _write_file(self, file_path: Path, content: str) -> None:
        """Write content via the client's fs/write_text_file capability.

        Raises:
            ToolError: when the client-side write fails.
        """
        connection, session_id, _ = self._load_state()
        write_request = WriteTextFileRequest(
            sessionId=session_id, path=str(file_path), content=content
        )
        try:
            await connection.writeTextFile(write_request)
        except Exception as e:
            raise ToolError(f"Error writing {file_path}: {e}") from e

    @classmethod
    def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
        """Render the pending edit as per-block diffs in a "tool_call" update."""
        args = event.args
        if not isinstance(args, SearchReplaceArgs):
            return None
        blocks = cls._parse_search_replace_blocks(args.content)
        return ToolCallStart(
            sessionUpdate="tool_call",
            title=cls.get_call_display(event).summary,
            toolCallId=event.tool_call_id,
            kind="edit",
            content=[
                FileEditToolCallContent(
                    type="diff",
                    path=args.file_path,
                    oldText=block.search,
                    newText=block.replace,
                )
                for block in blocks
            ],
            locations=[ToolCallLocation(path=args.file_path)],
            rawInput=args.model_dump_json(),
        )

    @classmethod
    def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
        """Report completion (with the applied diffs) or failure of the edit."""
        if event.error:
            return ToolCallProgress(
                sessionUpdate="tool_call_update",
                toolCallId=event.tool_call_id,
                status="failed",
            )
        result = event.result
        if not isinstance(result, SearchReplaceResult):
            return None
        blocks = cls._parse_search_replace_blocks(result.content)
        return ToolCallProgress(
            sessionUpdate="tool_call_update",
            toolCallId=event.tool_call_id,
            status="completed",
            content=[
                FileEditToolCallContent(
                    type="diff",
                    path=result.file,
                    oldText=block.search,
                    newText=block.replace,
                )
                for block in blocks
            ],
            locations=[ToolCallLocation(path=result.file)],
            rawOutput=result.model_dump_json(),
        )

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
from typing import cast
from acp.helpers import SessionUpdate
from acp.schema import AgentPlanUpdate, PlanEntry, PlanEntryPriority, PlanEntryStatus
from vibe import VIBE_ROOT
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
from vibe.core.tools.builtins.todo import (
Todo as CoreTodoTool,
TodoArgs,
TodoPriority,
TodoResult,
TodoState,
TodoStatus,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
# Re-export so callers can import TodoArgs from this module.
TodoArgs = TodoArgs
class AcpTodoState(TodoState, AcpToolState):
    """Todo tool state combined with the injected ACP session context."""

    pass
class Todo(CoreTodoTool, BaseAcpTool[AcpTodoState]):
    """Todo tool that mirrors the agent's todo list as an ACP plan update."""

    state: AcpTodoState
    prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "todo.md"

    @classmethod
    def _get_tool_state_class(cls) -> type[AcpTodoState]:
        return AcpTodoState

    @classmethod
    def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
        # The plan is only published from the result; no start update is sent.
        return None

    @classmethod
    def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
        """Convert the todo list into an ACP plan, dropping cancelled items.

        Returns None when the result is missing or not a TodoResult (e.g. the
        tool call failed), instead of crashing on attribute access.
        """
        # Guard instead of cast(): event.result may be None on a failed call,
        # which previously raised AttributeError inside this handler.
        result = event.result
        if not isinstance(result, TodoResult):
            return None
        todos = [todo for todo in result.todos if todo.status != TodoStatus.CANCELLED]
        matched_status: dict[TodoStatus, PlanEntryStatus] = {
            TodoStatus.PENDING: "pending",
            TodoStatus.IN_PROGRESS: "in_progress",
            TodoStatus.COMPLETED: "completed",
        }
        matched_priority: dict[TodoPriority, PlanEntryPriority] = {
            TodoPriority.LOW: "low",
            TodoPriority.MEDIUM: "medium",
            TodoPriority.HIGH: "high",
        }
        return AgentPlanUpdate(
            sessionUpdate="plan",
            entries=[
                PlanEntry(
                    content=todo.content,
                    status=matched_status[todo.status],
                    priority=matched_priority[todo.priority],
                )
                for todo in todos
            ],
        )

View File

@@ -0,0 +1,98 @@
from __future__ import annotations
from pathlib import Path
from acp import WriteTextFileRequest
from acp.helpers import SessionUpdate
from acp.schema import (
FileEditToolCallContent,
ToolCallLocation,
ToolCallProgress,
ToolCallStart,
)
from vibe import VIBE_ROOT
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
from vibe.core.tools.base import ToolError
from vibe.core.tools.builtins.write_file import (
WriteFile as CoreWriteFileTool,
WriteFileArgs,
WriteFileResult,
WriteFileState,
)
from vibe.core.types import ToolCallEvent, ToolResultEvent
class AcpWriteFileState(WriteFileState, AcpToolState):
    """Write-file tool state combined with the injected ACP session context."""

    pass
class WriteFile(CoreWriteFileTool, BaseAcpTool[AcpWriteFileState]):
    """Write-file tool that delegates writes to the ACP client's filesystem."""

    state: AcpWriteFileState
    prompt_path = (
        VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "write_file.md"
    )

    @classmethod
    def _get_tool_state_class(cls) -> type[AcpWriteFileState]:
        return AcpWriteFileState

    async def _write_file(self, args: WriteFileArgs, file_path: Path) -> None:
        """Write the file via the client's fs/write_text_file capability.

        Raises:
            ToolError: when the client-side write fails.
        """
        connection, session_id, _ = self._load_state()
        write_request = WriteTextFileRequest(
            sessionId=session_id, path=str(file_path), content=args.content
        )
        await self._send_in_progress_session_update()
        try:
            await connection.writeTextFile(write_request)
        except Exception as e:
            raise ToolError(f"Error writing {file_path}: {e}") from e

    @classmethod
    def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
        """Render the pending write as a full-file diff in a "tool_call" update."""
        args = event.args
        if not isinstance(args, WriteFileArgs):
            return None
        return ToolCallStart(
            sessionUpdate="tool_call",
            title=cls.get_call_display(event).summary,
            toolCallId=event.tool_call_id,
            kind="edit",
            content=[
                FileEditToolCallContent(
                    # oldText=None presents the write as a creation/overwrite.
                    type="diff", path=args.path, oldText=None, newText=args.content
                )
            ],
            locations=[ToolCallLocation(path=args.path)],
            rawInput=args.model_dump_json(),
        )

    @classmethod
    def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
        """Report completion (with the written content) or failure."""
        if event.error:
            return ToolCallProgress(
                sessionUpdate="tool_call_update",
                toolCallId=event.tool_call_id,
                status="failed",
            )
        result = event.result
        if not isinstance(result, WriteFileResult):
            return None
        return ToolCallProgress(
            sessionUpdate="tool_call_update",
            toolCallId=event.tool_call_id,
            status="completed",
            content=[
                FileEditToolCallContent(
                    type="diff", path=result.path, oldText=None, newText=result.content
                )
            ],
            locations=[ToolCallLocation(path=result.path)],
            rawOutput=result.model_dump_json(),
        )

Some files were not shown because too many files have changed in this diff Show More