mirror of
https://github.com/mistralai/mistral-vibe
synced 2026-04-25 17:14:55 +02:00
Initial commit
Co-Authored-By: Quentin Torroba <quentin.torroba@mistral.ai> Co-Authored-By: Laure Hugo <laure.hugo@mistral.ai> Co-Authored-By: Benjamin Trom <benjamin.trom@mistral.ai> Co-Authored-By: Mathias Gesbert <mathias.gesbert@ext.mistral.ai> Co-Authored-By: Michel Thomazo <michel.thomazo@mistral.ai> Co-Authored-By: Clément Drouin <clement.drouin@mistral.ai> Co-Authored-By: Vincent Guilloux <vincent.guilloux@mistral.ai> Co-Authored-By: Valentin Berard <val@mistral.ai> Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
12
.envrc
Normal file
12
.envrc
Normal file
@@ -0,0 +1,12 @@
|
||||
# shellcheck shell=bash
|
||||
if ! has nix_direnv_version || ! nix_direnv_version 3.1.0; then
|
||||
source_url "https://raw.githubusercontent.com/nix-community/nix-direnv/3.1.0/direnvrc" "sha256-yMJ2OVMzrFaDPn7q8nCBZFRYpL/f0RcHzhmw/i6btJM="
|
||||
fi
|
||||
|
||||
if command -v nix &>/dev/null; then
|
||||
use flake
|
||||
fi
|
||||
|
||||
if command -v pre-commit &>/dev/null; then
|
||||
pre-commit install
|
||||
fi
|
||||
7
.github/CODEOWNERS
vendored
Normal file
7
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
# CODEOWNERS
|
||||
|
||||
# Default owners for everything in the repo
|
||||
* @mistralai/mistral-vibe
|
||||
|
||||
# Owners for specific directories
|
||||
# Not needed for now, can be filled later
|
||||
86
.github/workflows/build-and-upload.yml
vendored
Normal file
86
.github/workflows/build-and-upload.yml
vendored
Normal file
@@ -0,0 +1,86 @@
|
||||
name: Build and upload
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
build-and-upload:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
# Linux
|
||||
- runner: ubuntu-22.04
|
||||
os: linux
|
||||
arch: x86_64
|
||||
# - runner: ubuntu-22.04-arm # ubuntu-22.04-arm, ubuntu-24.04-arm and windows-11-arm are not supported yet for private repositories
|
||||
# os: linux
|
||||
# arch: aarch64
|
||||
|
||||
# macOS
|
||||
- runner: macos-15-intel
|
||||
os: darwin
|
||||
arch: x86_64
|
||||
- runner: macos-14
|
||||
os: darwin
|
||||
arch: aarch64
|
||||
|
||||
# Windows
|
||||
- runner: windows-2022
|
||||
os: windows
|
||||
arch: x86_64
|
||||
# - runner: windows-11-arm # ubuntu-22.04-arm, ubuntu-24.04-arm and windows-11-arm are not supported yet for private repositories
|
||||
# os: windows
|
||||
# arch: aarch64
|
||||
runs-on: ${{ matrix.runner }}
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
|
||||
|
||||
- name: Install uv with caching
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Build with PyInstaller
|
||||
run: uv run pyinstaller vibe-acp.spec
|
||||
|
||||
- name: Get package version with uv (Unix)
|
||||
id: get_version_unix
|
||||
if: ${{ matrix.os != 'windows' }}
|
||||
run: python -c "import subprocess; version = subprocess.check_output(['uv', 'version']).decode().split()[1]; print(f'version={version}')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Get package version with uv (Windows)
|
||||
id: get_version_windows
|
||||
if: ${{ matrix.os == 'windows' }}
|
||||
shell: pwsh
|
||||
run: python -c "import subprocess; version = subprocess.check_output(['uv', 'version']).decode().split()[1]; print(f'version={version}')" >> $env:GITHUB_OUTPUT
|
||||
|
||||
- name: Upload binary as artifact (Unix)
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
|
||||
if: ${{ matrix.os != 'windows' }}
|
||||
with:
|
||||
name: vibe-acp-${{ matrix.os }}-${{ matrix.arch }}-${{ steps.get_version_unix.outputs.version }}
|
||||
path: dist/vibe-acp
|
||||
|
||||
- name: Upload binary as artifact (Windows)
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
|
||||
if: ${{ matrix.os == 'windows' }}
|
||||
with:
|
||||
name: vibe-acp-${{ matrix.os }}-${{ matrix.arch }}-${{ steps.get_version_windows.outputs.version }}
|
||||
path: dist/vibe-acp.exe
|
||||
126
.github/workflows/ci.yml
vendored
Normal file
126
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,126 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
PYTHON_VERSION: "3.12"
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
name: Pre-commit
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
|
||||
|
||||
- name: Install uv with caching
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Install pip (required by pre-commit)
|
||||
run: uv pip install pip
|
||||
|
||||
- name: Add virtual environment to PATH
|
||||
run: echo "$(pwd)/.venv/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Cache pre-commit
|
||||
uses: actions/cache@0057852bfaa89a56745cba8c7296529d2fc39830 # v4
|
||||
with:
|
||||
path: ~/.cache/pre-commit
|
||||
key: pre-commit-${{ hashFiles('.pre-commit-config.yaml') }}
|
||||
|
||||
- name: Run pre-commit
|
||||
run: uv run pre-commit run --all-files --show-diff-on-failure
|
||||
|
||||
tests:
|
||||
name: Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
|
||||
|
||||
- name: Install uv with caching
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Verify CLI can start
|
||||
run: |
|
||||
uv run vibe --help
|
||||
uv run vibe-acp --help
|
||||
|
||||
- name: Install ripgrep
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest --ignore tests/snapshots
|
||||
|
||||
snapshot-tests:
|
||||
name: Snapshot Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@08eba0b27e820071cde6df949e0beb9ba4906955 # v4
|
||||
|
||||
- name: Install uv with caching
|
||||
uses: astral-sh/setup-uv@1e862dfacbd1d6d858c55d9b792c756523627244 # v7
|
||||
with:
|
||||
version: "latest"
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-extras
|
||||
|
||||
- name: Run snapshot tests
|
||||
id: snapshot-tests
|
||||
run: uv run pytest tests/snapshots
|
||||
continue-on-error: true
|
||||
|
||||
- name: Upload snapshot report
|
||||
if: steps.snapshot-tests.outcome == 'failure'
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
|
||||
with:
|
||||
name: snapshot-report
|
||||
path: snapshot_report.html
|
||||
if-no-files-found: warn
|
||||
retention-days: 3
|
||||
|
||||
- name: Fail job if snapshot tests failed
|
||||
if: steps.snapshot-tests.outcome == 'failure'
|
||||
run: |
|
||||
echo "Snapshot tests failed, failing job."
|
||||
exit 1
|
||||
44
.github/workflows/release.yml
vendored
Normal file
44
.github/workflows/release.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name: Release to Pipy
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
release:
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/mistral-vibe
|
||||
permissions:
|
||||
id-token: write
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --locked --dev
|
||||
|
||||
- name: Build package
|
||||
run: uv build
|
||||
|
||||
- name: Upload artifacts
|
||||
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5
|
||||
with:
|
||||
name: dist
|
||||
path: dist/
|
||||
|
||||
- name: Publish distribution to PyPI
|
||||
if: github.repository == 'mistralai/mistral-vibe'
|
||||
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
|
||||
201
.gitignore
vendored
Normal file
201
.gitignore
vendored
Normal file
@@ -0,0 +1,201 @@
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.*cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# nix / direnv
|
||||
.direnv/
|
||||
result
|
||||
result-*
|
||||
|
||||
# Vibe runtime/session files; keep tools as part of repo when needed
|
||||
.vibe/*
|
||||
|
||||
# Tests run the agent in the playground, we don't need to keep the session files
|
||||
tests/playground/*
|
||||
.
|
||||
34
.pre-commit-config.yaml
Normal file
34
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
---
|
||||
repos:
|
||||
- repo: https://github.com/mpalmer/action-validator
|
||||
rev: v0.7.0
|
||||
hooks:
|
||||
- id: action-validator
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-toml
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
exclude: tests/snapshots/.*\.svg$
|
||||
|
||||
- repo: https://github.com/fsouza/mirrors-pyright
|
||||
rev: v1.1.407
|
||||
hooks:
|
||||
- id: pyright
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.14.5
|
||||
hooks:
|
||||
- id: ruff-check
|
||||
args: [--fix, --unsafe-fixes]
|
||||
- id: ruff-format
|
||||
args: [--check]
|
||||
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--write-changes]
|
||||
1
.python-version
Normal file
1
.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.12
|
||||
2
.typos.toml
Normal file
2
.typos.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[default]
|
||||
extend-ignore-re = ["(?m)^.*(#|//)\\s*typos:disable-line$", "datas"]
|
||||
3
.vscode/extensions.json
vendored
Normal file
3
.vscode/extensions.json
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"recommendations": ["ms-python.python", "charliermarsh.ruff"]
|
||||
}
|
||||
59
.vscode/launch.json
vendored
Normal file
59
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,59 @@
|
||||
{
|
||||
"version": "0.1.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "ACP Server",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "vibe/acp/entrypoint.py",
|
||||
"args": ["--workdir", "${workspaceFolder}"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Tests",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": ["-v", "-s"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Single Test",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": ["-k", "${input:test_identifier}", "-vvv", "-s", "--no-header"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}",
|
||||
"env": {
|
||||
"PYTHONPATH": "${workspaceFolder}"
|
||||
},
|
||||
"stopOnEntry": false,
|
||||
"subProcess": true
|
||||
},
|
||||
{
|
||||
"name": "CLI",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "vibe/cli/entrypoint.py",
|
||||
"args": [],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": false
|
||||
}
|
||||
],
|
||||
"inputs": [
|
||||
{
|
||||
"id": "test_identifier",
|
||||
"description": "Enter the test identifier (file, class, or function)",
|
||||
"default": "TestInitialization",
|
||||
"type": "promptString"
|
||||
}
|
||||
]
|
||||
}
|
||||
26
.vscode/settings.json
vendored
Normal file
26
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"[python]": {
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.fixAll.ruff": "explicit",
|
||||
"source.organizeImports.ruff": "explicit"
|
||||
},
|
||||
"editor.defaultFormatter": "charliermarsh.ruff",
|
||||
"editor.formatOnSave": true
|
||||
},
|
||||
"cursorpyright.analysis.typeCheckingMode": "strict",
|
||||
"editor.formatOnSave": true,
|
||||
"files.exclude": {
|
||||
".pytest_cache/**": true,
|
||||
".venv/**": true,
|
||||
"**/__pycache__": true,
|
||||
"dist/**": true,
|
||||
"build/**": true
|
||||
},
|
||||
"files.insertFinalNewline": true,
|
||||
"files.trimTrailingWhitespace": true,
|
||||
"python.analysis.typeCheckingMode": "strict",
|
||||
"python.testing.pytestArgs": ["tests"],
|
||||
"python.testing.pytestEnabled": true,
|
||||
"ruff.enable": true,
|
||||
"ruff.organizeImports": true
|
||||
}
|
||||
135
AGENTS.md
Normal file
135
AGENTS.md
Normal file
@@ -0,0 +1,135 @@
|
||||
# python312.rule
|
||||
# Rule for enforcing modern Python 3.12+ best practices.
|
||||
# Applies to all Python files (*.py) in the project.
|
||||
#
|
||||
# Guidelines covered:
|
||||
# - Use match-case syntax instead of if/elif/else for pattern matching.
|
||||
# - Use the walrus operator (:=) when it simplifies assignments and tests.
|
||||
# - Favor a "never nester" approach by avoiding deep nesting with early returns or guard clauses.
|
||||
# - Employ modern type hints using built-in generics (list, dict) and the union pipe (|) operator,
|
||||
# rather than deprecated types from the typing module (e.g., Optional, Union, Dict, List).
|
||||
# - Ensure code adheres to strong static typing practices compatible with static analyzers like pyright.
|
||||
# - Favor pathlib.Path methods for file system operations over older os.path functions.
|
||||
# - Write code in a declarative and minimalist style that clearly expresses its intent.
|
||||
# - Additional best practices including f-string formatting, comprehensions, context managers, and overall PEP 8 compliance.
|
||||
|
||||
description: "Modern Python 3.12+ best practices and style guidelines for coding."
|
||||
files: "**/*.py"
|
||||
|
||||
guidelines:
|
||||
- title: "Match-Case Syntax"
|
||||
description: >
|
||||
Prefer using the match-case construct over traditional if/elif/else chains when pattern matching
|
||||
is applicable. This leads to clearer, more concise, and more maintainable code.
|
||||
|
||||
- title: "Walrus Operator"
|
||||
description: >
|
||||
Utilize the walrus operator (:=) to streamline code where assignment and conditional testing can be combined.
|
||||
Use it judiciously when it improves readability and reduces redundancy.
|
||||
|
||||
- title: "Never Nester"
|
||||
description: >
|
||||
Aim to keep code flat by avoiding deep nesting. Use early returns, guard clauses, and refactoring to
|
||||
minimize nested structures, making your code more readable and maintainable.
|
||||
|
||||
- title: "Modern Type Hints"
|
||||
description: >
|
||||
Adopt modern type hinting by using built-in generics like list and dict, along with the pipe (|) operator
|
||||
for union types (e.g., int | None). Avoid older, deprecated constructs such as Optional, Union, Dict, and List
|
||||
from the typing module.
|
||||
|
||||
- title: "Strong Static Typing"
|
||||
description: >
|
||||
Write code with explicit and robust type annotations that are fully compatible with static type checkers
|
||||
like pyright. This ensures higher code reliability and easier maintenance.
|
||||
|
||||
- title: "Pydantic-First Parsing"
|
||||
description: >
|
||||
Prefer Pydantic v2's native validation over ad-hoc parsing. Use `model_validate`,
|
||||
`field_validator`, `from_attributes`, and field aliases to coerce external SDK/DTO objects.
|
||||
Avoid manual `getattr`/`hasattr` flows and custom constructors like `from_sdk` unless they are
|
||||
thin wrappers over `model_validate`. Keep normalization logic inside model validators so call sites
|
||||
remain declarative and typed.
|
||||
|
||||
- title: "Pathlib for File Operations"
|
||||
description: >
|
||||
Favor the use of pathlib.Path methods for file system operations. This approach offers a more
|
||||
readable, object-oriented way to handle file paths and enhances cross-platform compatibility,
|
||||
reducing reliance on legacy os.path functions.
|
||||
|
||||
- title: "Declarative and Minimalist Code"
|
||||
description: >
|
||||
Write code that is declarative—clearly expressing its intentions rather than focusing on implementation details.
|
||||
Strive to keep your code minimalist by removing unnecessary complexity and boilerplate. This approach improves
|
||||
readability, maintainability, and aligns with modern Python practices.
|
||||
|
||||
- title: "Additional Best Practices"
|
||||
description: >
|
||||
Embrace other modern Python idioms such as:
|
||||
- Using f-strings for string formatting.
|
||||
- Favoring comprehensions for building lists and dictionaries.
|
||||
- Employing context managers (with statements) for resource management.
|
||||
- Following PEP 8 guidelines to maintain overall code style consistency.
|
||||
|
||||
- title: "Exception Documentation"
|
||||
description: >
|
||||
Document exceptions accurately and minimally in docstrings:
|
||||
- Only document exceptions that are explicitly raised in the function implementation
|
||||
- Remove Raises entries for exceptions that are not directly raised
|
||||
- Include all possible exceptions from explicit raise statements
|
||||
- For public APIs, document exceptions from called functions if they are allowed to propagate
|
||||
- Avoid documenting built-in exceptions that are obvious (like TypeError from type hints)
|
||||
This ensures documentation stays accurate and maintainable, avoiding the common pitfall
|
||||
of listing every possible exception that could theoretically occur.
|
||||
|
||||
- title: "Modern Enum Usage"
|
||||
description: >
|
||||
Leverage Python's enum module effectively following modern practices:
|
||||
- Use StrEnum for string-based enums that need string comparison
|
||||
- Use IntEnum/IntFlag for performance-critical integer-based enums
|
||||
- Use auto() for automatic value assignment to maintain clean code
|
||||
- Always use UPPERCASE for enum members to avoid name clashes
|
||||
- Add methods to enums when behavior needs to be associated with values
|
||||
- Use @property for computed attributes rather than storing values
|
||||
- For type mixing, ensure mix-in types appear before Enum in base class sequence
|
||||
- Consider Flag/IntFlag for bit field operations
|
||||
- Use _generate_next_value_ for custom value generation
|
||||
- Implement __bool__ when enum boolean evaluation should depend on value
|
||||
This promotes type-safe constants, self-documenting code, and maintainable value sets.
|
||||
|
||||
- title: "No Inline Ignores"
|
||||
description: >
|
||||
Do not use inline suppressions like `# type: ignore[...]` or `# noqa[...]` in production code.
|
||||
Instead, fix types and lint warnings at the source by:
|
||||
- Refining signatures with generics (TypeVar), Protocols, or precise return types
|
||||
- Guarding with `isinstance` checks before attribute access
|
||||
- Using `typing.cast` when control flow guarantees the type
|
||||
- Extracting small helpers to create clearer, typed boundaries
|
||||
If a suppression is truly unavoidable at an external boundary, prefer a narrow, well-typed wrapper
|
||||
over in-line ignores, and document the rationale in code comments.
|
||||
|
||||
- title: "Pydantic Discriminated Unions"
|
||||
description: >
|
||||
When modeling variants with a discriminated union (e.g., on a `transport` field), do not narrow a
|
||||
field type in a subclass (e.g., overriding `transport: Literal['http']` with `Literal['streamable-http']`).
|
||||
This violates Liskov substitution and triggers type checker errors due to invariance of class attributes.
|
||||
Prefer sibling classes plus a shared mixin for common fields and helpers, and compose the union with
|
||||
`Annotated[Union[...], Field(discriminator='transport')]`.
|
||||
Example pattern:
|
||||
- Create a base with shared non-discriminator fields (e.g., `_MCPBase`).
|
||||
- Create a mixin with protocol-specific fields/methods (e.g., `_MCPHttpFields`), without a `transport`.
|
||||
- Define sibling final classes per variant (e.g., `MCPHttp`, `MCPStreamableHttp`, `MCPStdio`) that set
|
||||
`transport: Literal[...]` once in each final class.
|
||||
- Use `match` on the discriminator to narrow types at call sites.
|
||||
|
||||
- title: "Use uv for All Commands"
|
||||
description: >
|
||||
We use uv to manage our python environment. You should nevery try to run a bare python commands.
|
||||
Always run commands using `uv` instead of invoking `python` or `pip` directly.
|
||||
For example, use `uv add package` and `uv run script.py` rather than `pip install package` or `python script.py`.
|
||||
This practice helps avoid environment drift and leverages modern Python packaging best practices.
|
||||
Useful uv commands are:
|
||||
- uv add/remove <package> to manage dependencies
|
||||
- uv sync to install dependencies declared in pyproject.toml and uv.lock
|
||||
- uv run script.py to run a script within the uv environment
|
||||
- uv run pytest (or any other python tool) to run the tool within the uv environment
|
||||
10
CHANGELOG.md
Normal file
10
CHANGELOG.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
- Initial release
|
||||
168
CONTRIBUTING.md
Normal file
168
CONTRIBUTING.md
Normal file
@@ -0,0 +1,168 @@
|
||||
# Contributing to Mistral Vibe
|
||||
|
||||
Thank you for your interest in Mistral Vibe! We appreciate your enthusiasm and support.
|
||||
|
||||
## Current Status
|
||||
|
||||
**Mistral Vibe is in active development** — our team is iterating quickly and making lots of changes under the hood. Because of this pace, we may be slower than usual when reviewing PRs and issues.
|
||||
|
||||
**We especially encourage**:
|
||||
|
||||
- **Bug reports** – Help us uncover and squash issues
|
||||
- **Feedback & ideas** – Tell us what works, what doesn't, and what could be even better
|
||||
- **Documentation improvements** – Suggest clarity improvements or highlight missing pieces
|
||||
|
||||
## How to Provide Feedback
|
||||
|
||||
### Bug Reports
|
||||
|
||||
If you encounter a bug, please open an issue with the following information:
|
||||
|
||||
1. **Description**: A clear description of the bug
|
||||
2. **Steps to Reproduce**: Detailed steps to reproduce the issue
|
||||
3. **Expected Behavior**: What you expected to happen
|
||||
4. **Actual Behavior**: What actually happened
|
||||
5. **Environment**:
|
||||
- Python version
|
||||
- Operating system
|
||||
- Vibe version
|
||||
6. **Error Messages**: Any error messages or stack traces
|
||||
7. **Configuration**: Relevant parts of your `config.toml` (redact any sensitive information)
|
||||
|
||||
### Feature Requests and Feedback
|
||||
|
||||
We'd love to hear your ideas! When submitting feedback or feature requests:
|
||||
|
||||
1. **Clear Description**: Explain what you'd like to see or improve
|
||||
2. **Use Case**: Describe your use case and why this would be valuable
|
||||
3. **Alternatives**: If applicable, mention any alternatives you've considered
|
||||
|
||||
## Development Setup
|
||||
|
||||
This section is for developers who want to set up the repository for local development, even though we're not currently accepting contributions.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12 or higher
|
||||
- [uv](https://github.com/astral-sh/uv) - Modern Python package manager
|
||||
|
||||
### Setup
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
```bash
|
||||
git clone <repository-url>
|
||||
cd mistral-vibe
|
||||
```
|
||||
|
||||
2. Install dependencies:
|
||||
|
||||
```bash
|
||||
uv sync --all-extras
|
||||
```
|
||||
|
||||
This will install both runtime and development dependencies.
|
||||
|
||||
3. (Optional) Install pre-commit hooks:
|
||||
|
||||
```bash
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
Pre-commit hooks will automatically run checks before each commit.
|
||||
|
||||
### Running Tests
|
||||
|
||||
Run all tests:
|
||||
|
||||
```bash
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Run tests with verbose output:
|
||||
|
||||
```bash
|
||||
uv run pytest -v
|
||||
```
|
||||
|
||||
Run a specific test file:
|
||||
|
||||
```bash
|
||||
uv run pytest tests/test_agent_tool_call.py
|
||||
```
|
||||
|
||||
### Linting and Type Checking
|
||||
|
||||
#### Ruff (Linting and Formatting)
|
||||
|
||||
Check for linting issues (without fixing):
|
||||
|
||||
```bash
|
||||
uv run ruff check .
|
||||
```
|
||||
|
||||
Auto-fix linting issues:
|
||||
|
||||
```bash
|
||||
uv run ruff check --fix .
|
||||
```
|
||||
|
||||
Format code:
|
||||
|
||||
```bash
|
||||
uv run ruff format .
|
||||
```
|
||||
|
||||
Check formatting without modifying files (useful for CI):
|
||||
|
||||
```bash
|
||||
uv run ruff format --check .
|
||||
```
|
||||
|
||||
#### Pyright (Type Checking)
|
||||
|
||||
Run type checking:
|
||||
|
||||
```bash
|
||||
uv run pyright
|
||||
```
|
||||
|
||||
#### Pre-commit Hooks
|
||||
|
||||
Run all pre-commit hooks manually:
|
||||
|
||||
```bash
|
||||
uv run pre-commit run --all-files
|
||||
```
|
||||
|
||||
The pre-commit hooks include:
|
||||
|
||||
- Ruff (linting and formatting)
|
||||
- Pyright (type checking)
|
||||
- Typos (spell checking)
|
||||
- YAML/TOML validation
|
||||
- Action validator (for GitHub Actions)
|
||||
|
||||
### Code Style
|
||||
|
||||
- **Line length**: 88 characters (Black-compatible)
|
||||
- **Type hints**: Required for all functions and methods
|
||||
- **Docstrings**: Follow Google-style docstrings
|
||||
- **Formatting**: Use Ruff for both linting and formatting
|
||||
- **Type checking**: Use Pyright (configured in `pyproject.toml`)
|
||||
|
||||
See `pyproject.toml` for detailed configuration of Ruff and Pyright.
|
||||
|
||||
## Code Contributions
|
||||
|
||||
While we're not accepting code contributions at the moment, we may open up contributions in the future. When that happens, we'll update this document with:
|
||||
|
||||
- Pull request process
|
||||
- Contribution guidelines
|
||||
- Review process
|
||||
|
||||
## Questions?
|
||||
|
||||
If you have questions about using Mistral Vibe, please check the [README](README.md) first. For other inquiries, feel free to open a discussion or issue.
|
||||
|
||||
Thank you for helping make Mistral Vibe better! 🙏
|
||||
201
LICENSE
Normal file
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Mistral AI
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
308
README.md
Normal file
308
README.md
Normal file
@@ -0,0 +1,308 @@
|
||||
# Mistral Vibe
|
||||
|
||||
[](https://pypi.org/project/mistral-vibe)
|
||||
[](https://www.python.org/downloads/release/python-3120/)
|
||||
[](https://github.com/mistralai/mistral-vibe/actions/workflows/ci.yml)
|
||||
[](https://github.com/mistralai/mistral-vibe/blob/main/LICENSE)
|
||||
|
||||
```
|
||||
██████████████████░░
|
||||
██████████████████░░
|
||||
████ ██████ ████░░
|
||||
████ ██ ████░░
|
||||
████ ████░░
|
||||
████ ██ ██ ████░░
|
||||
██ ██ ██░░
|
||||
██████████████████░░
|
||||
██████████████████░░
|
||||
```
|
||||
|
||||
**Mistral's open-source CLI coding assistant.**
|
||||
|
||||
Mistral Vibe is a command-line coding assistant powered by Mistral's models. It provides a conversational interface to your codebase, allowing you to use natural language to explore, modify, and interact with your projects through a powerful set of tools.
|
||||
|
||||
> [!WARNING]
|
||||
> Mistral Vibe works on Windows, but we officially support and target UNIX environments.
|
||||
|
||||
## Installation
|
||||
|
||||
Vibe requires Python 3.12 or higher.
|
||||
|
||||
### One-line install (recommended)
|
||||
|
||||
```bash
|
||||
# On Linux and macOS
|
||||
curl -LsSf https://mistral.ai/vibe/install.sh | bash
|
||||
```
|
||||
|
||||
### Using uv
|
||||
|
||||
```bash
|
||||
uv tool install mistral-vibe
|
||||
```
|
||||
|
||||
### Using pip
|
||||
|
||||
```bash
|
||||
pip install mistral-vibe
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Interactive Chat**: A conversational AI agent that understands your requests and breaks down complex tasks.
|
||||
- **Powerful Toolset**: A suite of tools for file manipulation, code searching, version control, and command execution, right from the chat prompt.
|
||||
- Read, write, and patch files (`read_file`, `write_file`, `search_replace`).
|
||||
- Execute shell commands in a stateful terminal (`bash`).
|
||||
- Recursively search code with `grep` (with `ripgrep` support).
|
||||
- Manage a `todo` list to track the agent's work.
|
||||
- **Project-Aware Context**: Vibe automatically scans your project's file structure and Git status to provide relevant context to the agent, improving its understanding of your codebase.
|
||||
- **Advanced CLI Experience**: Built with modern libraries for a smooth and efficient workflow.
|
||||
- Autocompletion for slash commands (`/`) and file paths (`@`).
|
||||
- Persistent command history.
|
||||
- Beautiful Themes.
|
||||
- **Highly Configurable**: Customize models, providers, tool permissions, and UI preferences through a simple `config.toml` file.
|
||||
- **Safety First**: Features tool execution approval.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Navigate to your project's root directory:
|
||||
|
||||
```bash
|
||||
cd /path/to/your/project
|
||||
```
|
||||
|
||||
2. Run Vibe:
|
||||
|
||||
```bash
|
||||
vibe
|
||||
```
|
||||
|
||||
3. If this is your first time running Vibe, it will:
|
||||
|
||||
- Create a default configuration file at `~/.vibe/config.toml`
|
||||
- Prompt you to enter your API key if it's not already configured
|
||||
- Save your API key to `~/.vibe/.env` for future use
|
||||
|
||||
4. Start interacting with the agent!
|
||||
|
||||
```
|
||||
> Can you find all instances of the word "TODO" in the project?
|
||||
|
||||
🤖 The user wants to find all instances of "TODO". The `grep` tool is perfect for this. I will use it to search the current directory.
|
||||
|
||||
> grep(pattern="TODO", path=".")
|
||||
|
||||
... (grep tool output) ...
|
||||
|
||||
🤖 I found the following "TODO" comments in your project.
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Interactive Mode
|
||||
|
||||
Simply run `vibe` to enter the interactive chat loop.
|
||||
|
||||
- **Multi-line Input**: Press `Ctrl+J` or `Shift+Enter` for select terminals to insert a newline.
|
||||
- **File Paths**: Reference files in your prompt using the `@` symbol for smart autocompletion (e.g., `> Read the file @src/agent.py`).
|
||||
- **Shell Commands**: Prefix any command with `!` to execute it directly in your shell, bypassing the agent (e.g., `> !ls -l`).
|
||||
|
||||
You can start Vibe with a prompt with the following command:
|
||||
|
||||
```bash
|
||||
vibe "Refactor the main function in cli/main.py to be more modular."
|
||||
```
|
||||
|
||||
**Note**: The `--auto-approve` flag automatically approves all tool executions without prompting. In interactive mode, you can also toggle auto-approve on/off using `Shift+Tab`.
|
||||
|
||||
### Programmatic Mode
|
||||
|
||||
You can run Vibe non-interactively by piping input or using the `--prompt` flag. This is useful for scripting.
|
||||
|
||||
```bash
|
||||
vibe --prompt "Refactor the main function in cli/main.py to be more modular."
|
||||
```
|
||||
|
||||
by default it will use `auto-approve` mode.
|
||||
|
||||
### Slash Commands
|
||||
|
||||
Use slash commands for meta-actions and configuration changes during a session.
|
||||
|
||||
## Configuration
|
||||
|
||||
Vibe is configured via a `config.toml` file. It looks for this file first in `./.vibe/config.toml` and then falls back to `~/.vibe/config.toml`.
|
||||
|
||||
### API Key Configuration
|
||||
|
||||
Vibe supports multiple ways to configure your API keys:
|
||||
|
||||
1. **Interactive Setup (Recommended for first-time users)**: When you run Vibe for the first time or if your API key is missing, Vibe will prompt you to enter it. The key will be securely saved to `~/.vibe/.env` for future sessions.
|
||||
|
||||
2. **Environment Variables**: Set your API key as an environment variable:
|
||||
|
||||
```bash
|
||||
export MISTRAL_API_KEY="your_mistral_api_key"
|
||||
```
|
||||
|
||||
3. **`.env` File**: Create a `.env` file in `~/.vibe/` and add your API keys:
|
||||
|
||||
```bash
|
||||
MISTRAL_API_KEY=your_mistral_api_key
|
||||
```
|
||||
|
||||
Vibe automatically loads API keys from `~/.vibe/.env` on startup. Environment variables take precedence over the `.env` file if both are set.
|
||||
|
||||
**Note**: The `.env` file is specifically for API keys and other provider credentials. General Vibe configuration should be done in `config.toml`.
|
||||
|
||||
### Custom System Prompts
|
||||
|
||||
You can create custom system prompts to replace the default one (`prompts/core.md`). Create a markdown file in the `~/.vibe/prompts/` directory with your custom prompt content.
|
||||
|
||||
To use a custom system prompt, set the `system_prompt_id` in your configuration to match the filename (without the `.md` extension):
|
||||
|
||||
```toml
|
||||
# Use a custom system prompt
|
||||
system_prompt_id = "my_custom_prompt"
|
||||
```
|
||||
|
||||
This will load the prompt from `~/.vibe/prompts/my_custom_prompt.md`.
|
||||
|
||||
### Custom Agent Configurations
|
||||
|
||||
You can create custom agent configurations for specific use cases (e.g., red-teaming, specialized tasks) by adding agent-specific TOML files in the `~/.vibe/agents/` directory.
|
||||
|
||||
To use a custom agent, run Vibe with the `--agent` flag:
|
||||
|
||||
```bash
|
||||
vibe --agent my_custom_agent
|
||||
```
|
||||
|
||||
Vibe will look for a file named `my_custom_agent.toml` in the agents directory and apply its configuration.
|
||||
|
||||
Example custom agent configuration (`~/.vibe/agents/redteam.toml`):
|
||||
|
||||
```toml
|
||||
# Custom agent configuration for red-teaming
|
||||
active_model = "devstral-2"
|
||||
system_prompt_id = "redteam"
|
||||
|
||||
# Disable some tools for this agent
|
||||
disabled_tools = ["search_replace", "write_file"]
|
||||
|
||||
# Override tool permissions for this agent
|
||||
[tools.bash]
|
||||
permission = "always"
|
||||
|
||||
[tools.read_file]
|
||||
permission = "always"
|
||||
```
|
||||
|
||||
Note: this implies that you have setup a redteam prompt names `~/.vibe/prompts/redteam.md`
|
||||
|
||||
### MCP Server Configuration
|
||||
|
||||
You can configure MCP (Model Context Protocol) servers to extend Vibe's capabilities. Add MCP server configurations under the `mcp_servers` section:
|
||||
|
||||
```toml
|
||||
# Example MCP server configurations
|
||||
[[mcp_servers]]
|
||||
name = "my_http_server"
|
||||
transport = "http"
|
||||
url = "http://localhost:8000"
|
||||
headers = { "Authorization" = "Bearer my_token" }
|
||||
api_key_env = "MY_API_KEY_ENV_VAR"
|
||||
api_key_header = "Authorization"
|
||||
api_key_format = "Bearer {token}"
|
||||
|
||||
[[mcp_servers]]
|
||||
name = "my_streamable_server"
|
||||
transport = "streamable-http"
|
||||
url = "http://localhost:8001"
|
||||
headers = { "X-API-Key" = "my_api_key" }
|
||||
|
||||
[[mcp_servers]]
|
||||
name = "fetch_server"
|
||||
transport = "stdio"
|
||||
command = "uvx"
|
||||
args = ["mcp-server-fetch"]
|
||||
```
|
||||
|
||||
Supported transports:
|
||||
|
||||
- `http`: Standard HTTP transport
|
||||
- `streamable-http`: HTTP transport with streaming support
|
||||
- `stdio`: Standard input/output transport (for local processes)
|
||||
|
||||
Key fields:
|
||||
|
||||
- `name`: A short alias for the server (used in tool names)
|
||||
- `transport`: The transport type
|
||||
- `url`: Base URL for HTTP transports
|
||||
- `headers`: Additional HTTP headers
|
||||
- `api_key_env`: Environment variable containing the API key
|
||||
- `command`: Command to run for stdio transport
|
||||
- `args`: Additional arguments for stdio transport
|
||||
|
||||
### Enable/disable tools with patterns
|
||||
|
||||
You can control which tools are active using `enabled_tools` and `disabled_tools`.
|
||||
These fields support exact names, glob patterns, and regular expressions.
|
||||
|
||||
Examples:
|
||||
|
||||
```toml
|
||||
# Only enable tools that start with "serena_" (glob)
|
||||
enabled_tools = ["serena_*"]
|
||||
|
||||
# Regex (prefix with re:) — matches full tool name (case-insensitive)
|
||||
enabled_tools = ["re:^serena_.*$"]
|
||||
|
||||
# Heuristic regex support (patterns like `serena.*` are treated as regex)
|
||||
enabled_tools = ["serena.*"]
|
||||
|
||||
# Disable a group with glob; everything else stays enabled
|
||||
disabled_tools = ["mcp_*", "grep"]
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- MCP tool names use underscores, e.g., `serena_list` not `serena.list`.
|
||||
- Regex patterns are matched against the full tool name using fullmatch.
|
||||
|
||||
### Custom Vibe Home Directory
|
||||
|
||||
By default, Vibe stores its configuration in `~/.vibe/`. You can override this by setting the `VIBE_HOME` environment variable:
|
||||
|
||||
```bash
|
||||
export VIBE_HOME="/path/to/custom/vibe/home"
|
||||
```
|
||||
|
||||
This affects where Vibe looks for:
|
||||
|
||||
- `config.toml` - Main configuration
|
||||
- `.env` - API keys
|
||||
- `agents/` - Custom agent configurations
|
||||
- `prompts/` - Custom system prompts
|
||||
- `tools/` - Custom tools
|
||||
- `logs/` - Session logsRetryTo run code, enable code execution and file creation in Settings > Capabilities.
|
||||
|
||||
## Resources
|
||||
|
||||
- [CHANGELOG](CHANGELOG.md) - See what's new in each version
|
||||
- [CONTRIBUTING](CONTRIBUTING.md) - Guidelines for feedback and bug reports
|
||||
|
||||
## License
|
||||
|
||||
Copyright 2025 Mistral AI
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the [LICENSE](LICENSE) file for the full license text.
|
||||
64
action.yml
Normal file
64
action.yml
Normal file
@@ -0,0 +1,64 @@
|
||||
---
|
||||
name: Mistral Vibe
|
||||
description: "Download, install, and run Mistral Vibe"
|
||||
author: Mistral AI
|
||||
|
||||
inputs:
|
||||
prompt:
|
||||
description: The prompt to pass to the agent
|
||||
required: true
|
||||
default: |
|
||||
You are a helpful assistant
|
||||
MISTRAL_API_KEY:
|
||||
description: API Key for AI Studio
|
||||
required: true
|
||||
install_python:
|
||||
description: |
|
||||
Whether or not to install Python
|
||||
required: true
|
||||
default: "true"
|
||||
python_version:
|
||||
description: |
|
||||
Version of Python to install. Warning: Unsupported.
|
||||
required: false
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Install Required Python Version
|
||||
if: inputs.install_python == true
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version-file: ${{ github.action_path }}/.python-version
|
||||
|
||||
- name: Install Requested Python Version
|
||||
if: inputs.install_python && inputs.python_version
|
||||
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5
|
||||
with:
|
||||
python-version: ${{ inputs.python_version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
||||
- name: Install Mistral Vibe
|
||||
shell: bash
|
||||
working-directory: ${{ github.action_path }}
|
||||
run: |
|
||||
uv sync --locked --all-extras --dev
|
||||
|
||||
- name: Run Mistral Vibe
|
||||
id: run-mistral-vibe
|
||||
shell: bash
|
||||
working-directory: ${{ github.action_path }}
|
||||
env:
|
||||
MISTRAL_API_KEY: "${{ inputs.MISTRAL_API_KEY }}"
|
||||
run: |
|
||||
# We want to make sure that any text passed in here
|
||||
# doesn't have special bash characters (<, >, &, etc...)
|
||||
ESCAPED_PROMPT=$(printf '%q' "${{ inputs.prompt }}")
|
||||
|
||||
# Change back to the original working directory for the tool to work
|
||||
cd "${{ github.workspace }}"
|
||||
uv run --directory "${{ github.action_path }}" vibe \
|
||||
--auto-approve \
|
||||
-p "${ESCAPED_PROMPT}"
|
||||
35
distribution/zed/extension.toml
Normal file
35
distribution/zed/extension.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
id = "mistral-vibe"
|
||||
name = "Mistral Vibe"
|
||||
description = "Lightning-fast AI agent that actually gets things done"
|
||||
version = "1.0.0"
|
||||
schema_version = 1
|
||||
authors = ["Mistral AI"]
|
||||
repository = "https://github.com/mistralai/mistral-vibe"
|
||||
|
||||
[agent_servers.mistral-vibe-agent]
|
||||
name = "Mistral Vibe"
|
||||
icon = "./icons/mistral_vibe.svg"
|
||||
|
||||
[agent_servers.mistral-vibe-agent.targets.darwin-aarch64]
|
||||
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-darwin-aarch64-1.0.0.zip"
|
||||
cmd = "./vibe-acp"
|
||||
|
||||
[agent_servers.mistral-vibe-agent.targets.darwin-x86_64]
|
||||
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-darwin-x86_64-1.0.0.zip"
|
||||
cmd = "./vibe-acp"
|
||||
|
||||
# [agent_servers.mistral-vibe-agent.targets.linux-aarch64]
|
||||
# archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-linux-aarch64-1.0.0.zip"
|
||||
# cmd = "./vibe-acp"
|
||||
|
||||
[agent_servers.mistral-vibe-agent.targets.linux-x86_64]
|
||||
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-linux-x86_64-1.0.0.zip"
|
||||
cmd = "./vibe-acp"
|
||||
|
||||
# [agent_servers.mistral-vibe-agent.targets.windows-aarch64]
|
||||
# archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-windows-aarch64-1.0.0.zip"
|
||||
# cmd = "./vibe-acp.exe"
|
||||
|
||||
[agent_servers.mistral-vibe-agent.targets.windows-x86_64]
|
||||
archive = "https://github.com/mistralai/mistral-vibe/releases/download/v1.0.0/vibe-acp-windows-x86_64-1.0.0.zip"
|
||||
cmd = "./vibe-acp.exe"
|
||||
13
distribution/zed/icons/mistral_vibe.svg
Normal file
13
distribution/zed/icons/mistral_vibe.svg
Normal file
@@ -0,0 +1,13 @@
|
||||
<svg width="512" height="512" viewBox="0 0 512 512" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect width="512" height="512" fill="#24211E"/>
|
||||
<rect x="113.778" y="113.778" width="56.8889" height="56.8889" fill="white"/>
|
||||
<rect x="341.333" y="113.778" width="56.8889" height="56.8889" fill="white"/>
|
||||
<rect x="227.556" y="284.444" width="56.8889" height="56.8889" fill="white"/>
|
||||
<rect x="113.778" y="284.444" width="56.8889" height="56.8889" fill="white"/>
|
||||
<rect x="341.333" y="284.444" width="56.8889" height="56.8889" fill="white"/>
|
||||
<rect x="113.778" y="170.667" width="113.778" height="56.8889" fill="white"/>
|
||||
<rect x="56.8887" y="341.333" width="170.667" height="56.8889" fill="white"/>
|
||||
<rect x="284.444" y="341.333" width="170.667" height="56.8889" fill="white"/>
|
||||
<rect x="113.778" y="227.556" width="284.444" height="56.8889" fill="white"/>
|
||||
<rect x="284.444" y="170.667" width="113.778" height="56.8889" fill="white"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 935 B |
133
flake.lock
generated
Normal file
133
flake.lock
generated
Normal file
@@ -0,0 +1,133 @@
|
||||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1763283776,
|
||||
"narHash": "sha256-Y7TDFPK4GlqrKrivOcsHG8xSGqQx3A6c+i7novT85Uk=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "50a96edd8d0db6cc8db57dab6bb6d6ee1f3dc49a",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761781027,
|
||||
"narHash": "sha256-YDvxPAm2WnxrznRqWwHLjryBGG5Ey1ATEJXrON+TWt8=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "795a980d25301e5133eca37adae37283ec3c8e66",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763349549,
|
||||
"narHash": "sha256-GQKYN9j8HOh09RW2I739tyu87ygcsAmpJJ32FspWVJ8=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "071b718279182c5585f74939c2902c202f93f588",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
144
flake.nix
Normal file
144
flake.nix
Normal file
@@ -0,0 +1,144 @@
|
||||
{
|
||||
description = "Mistral Vibe!";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
};
|
||||
|
||||
outputs = {
|
||||
self,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
uv2nix,
|
||||
pyproject-nix,
|
||||
pyproject-build-systems,
|
||||
...
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (system: let
|
||||
inherit (nixpkgs) lib;
|
||||
|
||||
workspace = uv2nix.lib.workspace.loadWorkspace {workspaceRoot = ./.;};
|
||||
|
||||
overlay = workspace.mkPyprojectOverlay {
|
||||
sourcePreference = "wheel"; # sdist if you want
|
||||
};
|
||||
|
||||
pyprojectOverrides = final: prev: {
|
||||
# NOTE: If a package complains about a missing dependency (such
|
||||
# as setuptools), you can add it here.
|
||||
untokenize = prev.untokenize.overrideAttrs (old: {
|
||||
buildInputs = (old.buildInputs or []) ++ final.resolveBuildSystem {setuptools = [];};
|
||||
});
|
||||
};
|
||||
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
};
|
||||
|
||||
python = pkgs.python312;
|
||||
|
||||
# Construct package set
|
||||
pythonSet =
|
||||
# Use base package set from pyproject.nix builders
|
||||
(pkgs.callPackage pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope
|
||||
(
|
||||
lib.composeManyExtensions [
|
||||
pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
pyprojectOverrides
|
||||
]
|
||||
);
|
||||
in {
|
||||
packages.default = pythonSet.mkVirtualEnv "mistralai-vibe-env" workspace.deps.default;
|
||||
|
||||
apps = {
|
||||
default = {
|
||||
type = "app";
|
||||
program = "${self.packages.${system}.default}/bin/vibe";
|
||||
};
|
||||
};
|
||||
|
||||
devShells = {
|
||||
default = let
|
||||
editableOverlay = workspace.mkEditablePyprojectOverlay {
|
||||
root = "$REPO_ROOT";
|
||||
};
|
||||
|
||||
editablePythonSet = pythonSet.overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
editableOverlay
|
||||
|
||||
# Apply fixups for building an editable package of your workspace packages
|
||||
(final: prev: {
|
||||
mistralai-vibe = prev.mistralai-vibe.overrideAttrs (old: {
|
||||
# It's a good idea to filter the sources going into an editable build
|
||||
# so the editable package doesn't have to be rebuilt on every change.
|
||||
src = lib.fileset.toSource {
|
||||
root = old.src;
|
||||
fileset = lib.fileset.unions [
|
||||
(old.src + "/pyproject.toml")
|
||||
(old.src + "/README.md")
|
||||
];
|
||||
};
|
||||
|
||||
nativeBuildInputs =
|
||||
old.nativeBuildInputs
|
||||
++ final.resolveBuildSystem {
|
||||
editables = [];
|
||||
};
|
||||
});
|
||||
})
|
||||
]
|
||||
);
|
||||
|
||||
virtualenv = editablePythonSet.mkVirtualEnv "mistralai-vibe-dev-env" workspace.deps.all;
|
||||
in
|
||||
pkgs.mkShell {
|
||||
packages = [
|
||||
virtualenv
|
||||
pkgs.uv
|
||||
];
|
||||
|
||||
env = {
|
||||
# Don't create venv using uv
|
||||
UV_NO_SYNC = "1";
|
||||
|
||||
# Force uv to use Python interpreter from venv
|
||||
UV_PYTHON = "${virtualenv}/bin/python";
|
||||
|
||||
# Prevent uv from downloading managed Python's
|
||||
UV_PYTHON_DOWNLOADS = "never";
|
||||
};
|
||||
|
||||
shellHook = ''
|
||||
# Undo dependency propagation by nixpkgs.
|
||||
unset PYTHONPATH
|
||||
|
||||
# Get repository root using git. This is expanded at runtime by the editable `.pth` machinery.
|
||||
export REPO_ROOT=$(git rev-parse --show-toplevel)
|
||||
'';
|
||||
};
|
||||
};
|
||||
});
|
||||
}
|
||||
157
pyproject.toml
Normal file
157
pyproject.toml
Normal file
@@ -0,0 +1,157 @@
|
||||
[project]
|
||||
name = "mistral-vibe"
|
||||
version = "1.0.0"
|
||||
description = "Minimal CLI coding agent by Mistral"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
license = { text = "Apache-2.0" }
|
||||
authors = [{ name = "Mistral AI" }]
|
||||
keywords = [
|
||||
"ai",
|
||||
"cli",
|
||||
"coding-assistant",
|
||||
"mistral",
|
||||
"llm",
|
||||
"developer-tools",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 5 - Production/Stable",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Software Development",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Utilities",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-client-protocol==0.6.3",
|
||||
"aiofiles>=24.1.0",
|
||||
"httpx>=0.28.1",
|
||||
"mcp>=1.14.0",
|
||||
"mistralai==1.9.11",
|
||||
"pexpect>=4.9.0",
|
||||
"packaging>=24.1",
|
||||
"pydantic>=2.12.4",
|
||||
"pydantic-settings>=2.12.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
"pytest-xdist>=3.8.0",
|
||||
"rich>=14.0.0",
|
||||
"textual>=1.0.0",
|
||||
"tomli-w>=1.2.0",
|
||||
"watchfiles>=1.1.1",
|
||||
"pyperclip>=1.11.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/mistralai/mistral-vibe"
|
||||
Repository = "https://github.com/mistralai/mistral-vibe"
|
||||
Issues = "https://github.com/mistralai/mistral-vibe/issues"
|
||||
Documentation = "https://github.com/mistralai/mistral-vibe#readme"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling", "hatch-vcs", "editables"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
include = ["vibe/"]
|
||||
|
||||
|
||||
[project.scripts]
|
||||
vibe = "vibe.cli.entrypoint:main"
|
||||
vibe-acp = "vibe.acp.entrypoint:main"
|
||||
|
||||
|
||||
[tool.uv]
|
||||
package = true
|
||||
required-version = ">=0.8.0"
|
||||
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pre-commit>=4.2.0",
|
||||
"pyright>=1.1.403",
|
||||
"pytest>=8.3.5",
|
||||
"pytest-asyncio>=1.2.0",
|
||||
"pytest-timeout>=2.4.0",
|
||||
"pytest-textual-snapshot>=1.1.0",
|
||||
"respx>=0.22.0",
|
||||
"ruff>=0.14.5",
|
||||
"twine>=5.0.0",
|
||||
"typos>=1.34.0",
|
||||
"vulture>=2.14",
|
||||
"pyinstaller>=6.17.0",
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
pythonVersion = "3.12"
|
||||
reportMissingTypeStubs = false
|
||||
reportPrivateImportUsage = false
|
||||
include = ["vibe/**/*.py", "tests/**/*.py"]
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
|
||||
[tool.ruff]
|
||||
include = ["vibe/**/*.py", "tests/**/*.py"]
|
||||
line-length = 88
|
||||
target-version = "py312"
|
||||
preview = true
|
||||
|
||||
[tool.ruff.format]
|
||||
skip-magic-trailing-comma = true
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"F",
|
||||
"I",
|
||||
"D2",
|
||||
"UP",
|
||||
"TID",
|
||||
"ANN",
|
||||
"PLR",
|
||||
"B0",
|
||||
"B905",
|
||||
"DOC102",
|
||||
"RUF022",
|
||||
"RUF010",
|
||||
"RUF012",
|
||||
"RUF019",
|
||||
"RUF100",
|
||||
]
|
||||
ignore = ["D203", "D205", "D213", "ANN401", "PLR6301"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["ANN", "PLR"]
|
||||
|
||||
[tool.ruff.lint.flake8-tidy-imports]
|
||||
ban-relative-imports = "all"
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["vibe"]
|
||||
force-sort-within-sections = true
|
||||
split-on-trailing-comma = true
|
||||
combine-as-imports = true
|
||||
force-wrap-aliases = false
|
||||
order-by-type = true
|
||||
required-imports = ["from __future__ import annotations"]
|
||||
|
||||
[tool.ruff.lint.pylint]
|
||||
max-statements = 50
|
||||
max-branches = 15
|
||||
max-locals = 15
|
||||
max-args = 9
|
||||
max-returns = 6
|
||||
max-nested-blocks = 4
|
||||
|
||||
[tool.vulture]
|
||||
ignore_decorators = ["@*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "-vvvv -q -n auto --durations=5 --import-mode=importlib"
|
||||
timeout = 10
|
||||
20
scripts/README.md
Normal file
20
scripts/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Project Management Scripts
|
||||
|
||||
This directory contains scripts that support project versioning and deployment workflows.
|
||||
|
||||
## Versioning
|
||||
|
||||
### Usage
|
||||
|
||||
```bash
|
||||
# Bump major version (1.0.0 -> 2.0.0)
|
||||
uv run scripts/bump_version.py major
|
||||
|
||||
# Bump minor version (1.0.0 -> 1.1.0)
|
||||
uv run scripts/bump_version.py minor
|
||||
|
||||
# Bump patch/micro version (1.0.0 -> 1.0.1)
|
||||
uv run scripts/bump_version.py micro
|
||||
# or
|
||||
uv run scripts/bump_version.py patch
|
||||
```
|
||||
138
scripts/bump_version.py
Executable file
138
scripts/bump_version.py
Executable file
@@ -0,0 +1,138 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Version bumping script for semver versioning.
|
||||
|
||||
This script increments the version in pyproject.toml based on the specified bump type:
|
||||
- major: 1.0.0 -> 2.0.0
|
||||
- minor: 1.0.0 -> 1.1.0
|
||||
- micro/patch: 1.0.0 -> 1.0.1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Literal, get_args
|
||||
|
||||
BumpType = Literal["major", "minor", "micro", "patch"]
|
||||
BUMP_TYPES = get_args(BumpType)
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||
match = re.match(r"^(\d+)\.(\d+)\.(\d+)$", version_str.strip())
|
||||
if not match:
|
||||
raise ValueError(f"Invalid version format: {version_str}")
|
||||
|
||||
return int(match.group(1)), int(match.group(2)), int(match.group(3))
|
||||
|
||||
|
||||
def format_version(major: int, minor: int, patch: int) -> str:
|
||||
return f"{major}.{minor}.{patch}"
|
||||
|
||||
|
||||
def bump_version(version: str, bump_type: BumpType) -> str:
|
||||
major, minor, patch = parse_version(version)
|
||||
|
||||
match bump_type:
|
||||
case "major":
|
||||
return format_version(major + 1, 0, 0)
|
||||
case "minor":
|
||||
return format_version(major, minor + 1, 0)
|
||||
case "micro" | "patch":
|
||||
return format_version(major, minor, patch + 1)
|
||||
|
||||
|
||||
def update_hard_values_files(filepath: str, patterns: list[tuple[str, str]]) -> None:
|
||||
path = Path(filepath)
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"{filepath} not found in current directory")
|
||||
|
||||
# Replace patterns
|
||||
for pattern, replacement in patterns:
|
||||
content = path.read_text()
|
||||
updated_content = re.sub(pattern, replacement, content, flags=re.MULTILINE)
|
||||
|
||||
if updated_content == content:
|
||||
raise ValueError(f"pattern {pattern} not found in {filepath}")
|
||||
|
||||
path.write_text(updated_content)
|
||||
|
||||
print(f"Updated version in {filepath}")
|
||||
|
||||
|
||||
def get_current_version() -> str:
|
||||
pyproject_path = Path("pyproject.toml")
|
||||
|
||||
if not pyproject_path.exists():
|
||||
raise FileNotFoundError("pyproject.toml not found in current directory")
|
||||
|
||||
content = pyproject_path.read_text()
|
||||
|
||||
# Find version line
|
||||
version_match = re.search(r'^version = "([^"]+)"$', content, re.MULTILINE)
|
||||
if not version_match:
|
||||
raise ValueError("Version not found in pyproject.toml")
|
||||
|
||||
return version_match.group(1)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Bump semver version in pyproject.toml",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
uv run scripts/bump_version.py major # 1.0.0 -> 2.0.0
|
||||
uv run scripts/bump_version.py minor # 1.0.0 -> 1.1.0
|
||||
uv run scripts/bump_version.py micro # 1.0.0 -> 1.0.1
|
||||
uv run scripts/bump_version.py patch # 1.0.0 -> 1.0.1
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"bump_type", choices=BUMP_TYPES, help="Type of version bump to perform"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Get current version
|
||||
current_version = get_current_version()
|
||||
print(f"Current version: {current_version}")
|
||||
|
||||
# Calculate new version
|
||||
new_version = bump_version(current_version, args.bump_type)
|
||||
print(f"New version: {new_version}")
|
||||
|
||||
# Update pyproject.toml
|
||||
update_hard_values_files(
|
||||
"pyproject.toml",
|
||||
[(f'version = "{current_version}"', f'version = "{new_version}"')],
|
||||
)
|
||||
# Update extension.toml
|
||||
update_hard_values_files(
|
||||
"distribution/zed/extension.toml",
|
||||
[
|
||||
(f'version = "{current_version}"', f'version = "{new_version}"'),
|
||||
(
|
||||
f"releases/download/v{current_version}",
|
||||
f"releases/download/v{new_version}",
|
||||
),
|
||||
(f"-{current_version}.zip", f"-{new_version}.zip"),
|
||||
],
|
||||
)
|
||||
|
||||
subprocess.run(["uv", "lock"], check=True)
|
||||
|
||||
print(f"Successfully bumped version from {current_version} to {new_version}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
128
scripts/install.sh
Executable file
128
scripts/install.sh
Executable file
@@ -0,0 +1,128 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Mistral Vibe Installation Script
|
||||
# This script installs uv if not present and then installs mistral-vibe using uv
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
BLUE='\033[0;34m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
function error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1" >&2
|
||||
}
|
||||
|
||||
function info() {
|
||||
echo -e "${BLUE}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
function success() {
|
||||
echo -e "${GREEN}[SUCCESS]${NC} $1"
|
||||
}
|
||||
|
||||
function warning() {
|
||||
echo -e "${YELLOW}[WARNING]${NC} $1"
|
||||
}
|
||||
|
||||
function check_platform() {
|
||||
local platform=$(uname -s)
|
||||
|
||||
if [[ "$platform" == "Linux" ]]; then
|
||||
info "Detected Linux platform"
|
||||
PLATFORM="linux"
|
||||
elif [[ "$platform" == "Darwin" ]]; then
|
||||
info "Detected macOS platform"
|
||||
PLATFORM="macos"
|
||||
else
|
||||
error "Unsupported platform: $platform"
|
||||
error "This installation script currently only supports Linux and macOS"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
function check_uv_installed() {
|
||||
if command -v uv &> /dev/null; then
|
||||
info "uv is already installed: $(uv --version)"
|
||||
UV_INSTALLED=true
|
||||
else
|
||||
info "uv is not installed"
|
||||
UV_INSTALLED=false
|
||||
fi
|
||||
}
|
||||
|
||||
function install_uv() {
|
||||
info "Installing uv using the official Astral installer..."
|
||||
|
||||
if ! command -v curl &> /dev/null; then
|
||||
error "curl is required to install uv. Please install curl first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if curl -LsSf https://astral.sh/uv/install.sh | sh; then
|
||||
success "uv installed successfully"
|
||||
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
if ! command -v uv &> /dev/null; then
|
||||
warning "uv was installed but not found in PATH for this session"
|
||||
warning "You may need to restart your terminal or run:"
|
||||
warning " export PATH=\"\$HOME/.cargo/bin:\$HOME/.local/bin:\$PATH\""
|
||||
fi
|
||||
else
|
||||
error "Failed to install uv"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
function install_vibe() {
|
||||
info "Installing mistral-vibe from GitHub repository using uv..."
|
||||
uv tool install mistral-vibe
|
||||
|
||||
success "Mistral Vibe installed successfully! (commands: vibe, vibe-acp)"
|
||||
}
|
||||
|
||||
function main() {
|
||||
echo
|
||||
echo "██████████████████░░"
|
||||
echo "██████████████████░░"
|
||||
echo "████ ██████ ████░░"
|
||||
echo "████ ██ ████░░"
|
||||
echo "████ ████░░"
|
||||
echo "████ ██ ██ ████░░"
|
||||
echo "██ ██ ██░░"
|
||||
echo "██████████████████░░"
|
||||
echo "██████████████████░░"
|
||||
echo
|
||||
echo "Starting Mistral Vibe installation..."
|
||||
echo
|
||||
|
||||
check_platform
|
||||
|
||||
check_uv_installed
|
||||
|
||||
if [[ "$UV_INSTALLED" == "false" ]]; then
|
||||
install_uv
|
||||
fi
|
||||
|
||||
install_vibe
|
||||
|
||||
if command -v vibe &> /dev/null; then
|
||||
success "Installation completed successfully!"
|
||||
echo
|
||||
echo "You can now run vibe with:"
|
||||
echo " vibe"
|
||||
echo
|
||||
echo "Or for ACP mode:"
|
||||
echo " vibe-acp"
|
||||
else
|
||||
error "Installation completed but 'vibe' command not found"
|
||||
error "Please check your installation and PATH settings"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
main
|
||||
5
tests/__init__.py
Normal file
5
tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
TESTS_ROOT = Path(__file__).parent
|
||||
925
tests/acp/test_acp.py
Normal file
925
tests/acp/test_acp.py
Normal file
@@ -0,0 +1,925 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from acp import (
|
||||
InitializeRequest,
|
||||
NewSessionRequest,
|
||||
PromptRequest,
|
||||
ReadTextFileRequest,
|
||||
ReadTextFileResponse,
|
||||
RequestPermissionRequest,
|
||||
RequestPermissionResponse,
|
||||
WriteTextFileRequest,
|
||||
)
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AllowedOutcome,
|
||||
DeniedOutcome,
|
||||
Implementation,
|
||||
InitializeResponse,
|
||||
McpCapabilities,
|
||||
NewSessionResponse,
|
||||
PromptCapabilities,
|
||||
PromptResponse,
|
||||
SessionNotification,
|
||||
TextContentBlock,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
from tests import TESTS_ROOT
|
||||
from tests.mock.utils import get_mocking_env, mock_llm_chunk
|
||||
from vibe.acp.utils import ToolOption
|
||||
from vibe.core.types import FunctionCall, ToolCall
|
||||
|
||||
RESPONSE_TIMEOUT = 2.0
|
||||
MOCK_ENTRYPOINT_PATH = "tests/mock/mock_entrypoint.py"
|
||||
PLAYGROUND_DIR = TESTS_ROOT / "playground"
|
||||
|
||||
|
||||
class JsonRpcRequest(BaseModel):
|
||||
jsonrpc: str = "2.0"
|
||||
id: int | str
|
||||
method: str
|
||||
params: Any | None = None
|
||||
|
||||
|
||||
class JsonRpcError(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Any | None = None
|
||||
|
||||
|
||||
class JsonRpcResponse(BaseModel):
|
||||
jsonrpc: str = "2.0"
|
||||
id: int | str | None = None
|
||||
result: Any | None = None
|
||||
error: JsonRpcError | None = None
|
||||
|
||||
|
||||
class JsonRpcNotification(BaseModel):
|
||||
jsonrpc: str = "2.0"
|
||||
method: str
|
||||
params: Any | None = None
|
||||
|
||||
|
||||
type JsonRpcMessage = JsonRpcResponse | JsonRpcNotification | JsonRpcRequest
|
||||
|
||||
|
||||
class InitializeJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "initialize"
|
||||
params: InitializeRequest | None = None
|
||||
|
||||
|
||||
class InitializeJsonRpcResponse(JsonRpcResponse):
|
||||
result: InitializeResponse | None = None
|
||||
|
||||
|
||||
class NewSessionJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "session/new"
|
||||
params: NewSessionRequest | None = None
|
||||
|
||||
|
||||
class NewSessionJsonRpcResponse(JsonRpcResponse):
|
||||
result: NewSessionResponse | None = None
|
||||
|
||||
|
||||
class PromptJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "session/prompt"
|
||||
params: PromptRequest | None = None
|
||||
|
||||
|
||||
class PromptJsonRpcResponse(JsonRpcResponse):
|
||||
result: PromptResponse | None = None
|
||||
|
||||
|
||||
class UpdateJsonRpcNotification(JsonRpcNotification):
|
||||
method: str = "session/update"
|
||||
params: SessionNotification | None = None
|
||||
|
||||
|
||||
class RequestPermissionJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "session/request_permission"
|
||||
params: RequestPermissionRequest | None = None
|
||||
|
||||
|
||||
class RequestPermissionJsonRpcResponse(JsonRpcResponse):
|
||||
result: RequestPermissionResponse | None = None
|
||||
|
||||
|
||||
class ReadTextFileJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "fs/read_text_file"
|
||||
params: ReadTextFileRequest | None = None
|
||||
|
||||
|
||||
class ReadTextFileJsonRpcResponse(JsonRpcResponse):
|
||||
result: ReadTextFileResponse | None = None
|
||||
|
||||
|
||||
class WriteTextFileJsonRpcRequest(JsonRpcRequest):
|
||||
method: str = "fs/write_text_file"
|
||||
params: WriteTextFileRequest | None = None
|
||||
|
||||
|
||||
class WriteTextFileJsonRpcResponse(JsonRpcResponse):
|
||||
result: None = None
|
||||
|
||||
|
||||
async def get_acp_agent_process(
|
||||
mock: bool = True, mock_env: dict[str, str] | None = None
|
||||
) -> AsyncGenerator[asyncio.subprocess.Process]:
|
||||
current_env = os.environ.copy()
|
||||
cmd = ["uv", "run", MOCK_ENTRYPOINT_PATH if mock else "vibe-acp"]
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=TESTS_ROOT.parent,
|
||||
env={
|
||||
**current_env,
|
||||
**(mock_env or {}),
|
||||
**({"MISTRAL_API_KEY": "mock"} if mock else {}),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield process
|
||||
finally:
|
||||
# Cleanup
|
||||
if process.returncode is None:
|
||||
process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=0.5)
|
||||
except TimeoutError:
|
||||
process.kill()
|
||||
await process.wait()
|
||||
|
||||
|
||||
async def send_json_rpc(
|
||||
process: asyncio.subprocess.Process, message: JsonRpcMessage
|
||||
) -> None:
|
||||
if process.stdin is None:
|
||||
raise RuntimeError("Process stdin not available")
|
||||
|
||||
request = message.model_dump_json()
|
||||
request_json = request + "\n"
|
||||
process.stdin.write(request_json.encode())
|
||||
await process.stdin.drain()
|
||||
|
||||
|
||||
async def read_response(
|
||||
process: asyncio.subprocess.Process, timeout: float = RESPONSE_TIMEOUT
|
||||
) -> str | None:
|
||||
if process.stdout is None:
|
||||
raise RuntimeError("Process stdout not available")
|
||||
|
||||
try:
|
||||
# Keep reading lines until we find a valid JSON line
|
||||
while True:
|
||||
line = await asyncio.wait_for(process.stdout.readline(), timeout=timeout)
|
||||
|
||||
if not line:
|
||||
return None
|
||||
|
||||
line_str = line.decode().strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
try:
|
||||
json.loads(line_str)
|
||||
return line_str
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON, skip it (it's a log message)
|
||||
continue
|
||||
except TimeoutError:
|
||||
return None
|
||||
|
||||
|
||||
async def read_response_for_id(
|
||||
process: asyncio.subprocess.Process,
|
||||
expected_id: int | str,
|
||||
timeout: float = RESPONSE_TIMEOUT,
|
||||
) -> str | None:
|
||||
loop = asyncio.get_running_loop()
|
||||
end_time = loop.time() + timeout
|
||||
|
||||
while (remaining := end_time - loop.time()) > 0:
|
||||
response = await read_response(process, timeout=remaining)
|
||||
if response is None:
|
||||
return None
|
||||
|
||||
response_json = json.loads(response)
|
||||
if response_json.get("id") == expected_id:
|
||||
return response
|
||||
print(
|
||||
f"Skipping response with id={response_json.get('id')}, expecting {expected_id}"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def read_multiple_responses(
|
||||
process: asyncio.subprocess.Process,
|
||||
max_count: int = 10,
|
||||
timeout_per_response: float = RESPONSE_TIMEOUT,
|
||||
) -> list[str]:
|
||||
responses = []
|
||||
for _ in range(max_count):
|
||||
response = await read_response(process, timeout=timeout_per_response)
|
||||
if response:
|
||||
responses.append(response)
|
||||
else:
|
||||
break
|
||||
return responses
|
||||
|
||||
|
||||
def parse_conversation(message_texts: list[str]) -> list[JsonRpcMessage]:
|
||||
parsed_messages: list[JsonRpcMessage] = []
|
||||
for message_text in message_texts:
|
||||
message_json = json.loads(message_text)
|
||||
cls = None
|
||||
has_method = message_json.get("method", None) is not None
|
||||
has_id = message_json.get("id", None) is not None
|
||||
has_result = message_json.get("result", None) is not None
|
||||
|
||||
is_request = has_method and has_id
|
||||
is_notification = has_method and not has_id
|
||||
is_response = has_result
|
||||
|
||||
if is_request:
|
||||
match message_json.get("method"):
|
||||
case "session/prompt":
|
||||
cls = PromptJsonRpcRequest
|
||||
case "session/request_permission":
|
||||
cls = RequestPermissionJsonRpcRequest
|
||||
case "fs/read_text_file":
|
||||
cls = ReadTextFileJsonRpcRequest
|
||||
case "fs/write_text_file":
|
||||
cls = WriteTextFileJsonRpcRequest
|
||||
elif is_notification:
|
||||
match message_json.get("method"):
|
||||
case "session/update":
|
||||
cls = UpdateJsonRpcNotification
|
||||
elif is_response:
|
||||
# For responses, since we don't know the method, we need to find
|
||||
# the matching request.
|
||||
matching_request = next(
|
||||
(
|
||||
m
|
||||
for m in parsed_messages
|
||||
if isinstance(m, JsonRpcRequest) and m.id == message_json.get("id")
|
||||
),
|
||||
None,
|
||||
)
|
||||
if matching_request is None:
|
||||
# No matching request found in the conversation, it most probably was
|
||||
# not included in the conversation. We use a generic response class.
|
||||
cls = JsonRpcResponse
|
||||
else:
|
||||
match matching_request.method:
|
||||
case "session/prompt":
|
||||
cls = PromptJsonRpcResponse
|
||||
case "session/request_permission":
|
||||
cls = RequestPermissionJsonRpcResponse
|
||||
case "fs/read_text_file":
|
||||
cls = ReadTextFileJsonRpcResponse
|
||||
case "fs/write_text_file":
|
||||
cls = WriteTextFileJsonRpcResponse
|
||||
if cls is None:
|
||||
raise ValueError(f"No valid message class found for {message_json}")
|
||||
parsed_messages.append(cls.model_validate(message_json))
|
||||
return parsed_messages
|
||||
|
||||
|
||||
async def initialize_session(acp_agent_process: asyncio.subprocess.Process) -> str:
|
||||
await send_json_rpc(
|
||||
acp_agent_process,
|
||||
InitializeJsonRpcRequest(id=1, params=InitializeRequest(protocolVersion=1)),
|
||||
)
|
||||
initialize_response = await read_response_for_id(
|
||||
acp_agent_process, expected_id=1, timeout=5.0
|
||||
)
|
||||
assert initialize_response is not None
|
||||
|
||||
await send_json_rpc(
|
||||
acp_agent_process,
|
||||
NewSessionJsonRpcRequest(
|
||||
id=2, params=NewSessionRequest(cwd=str(PLAYGROUND_DIR), mcpServers=[])
|
||||
),
|
||||
)
|
||||
session_response = await read_response_for_id(acp_agent_process, expected_id=2)
|
||||
assert session_response is not None
|
||||
session_response_json = json.loads(session_response)
|
||||
session_response_obj = NewSessionJsonRpcResponse.model_validate(
|
||||
session_response_json
|
||||
)
|
||||
assert session_response_obj.result is not None, "No result in response"
|
||||
return session_response_obj.result.sessionId
|
||||
|
||||
|
||||
class TestInitialization:
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_request_response(self) -> None:
|
||||
mock_env = get_mocking_env()
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
await send_json_rpc(
|
||||
process,
|
||||
InitializeJsonRpcRequest(
|
||||
id=1, params=InitializeRequest(protocolVersion=1)
|
||||
),
|
||||
)
|
||||
|
||||
text_response = await read_response(process, timeout=10.0)
|
||||
assert text_response is not None, "No response to initialize"
|
||||
response_json = json.loads(text_response)
|
||||
response = InitializeJsonRpcResponse.model_validate(response_json)
|
||||
assert response.error is None, f"JSON-RPC error: {response.error}"
|
||||
assert response.result is not None, "No result in response"
|
||||
assert response.result.protocolVersion == 1
|
||||
assert response.result.agentCapabilities == AgentCapabilities(
|
||||
loadSession=False,
|
||||
promptCapabilities=PromptCapabilities(
|
||||
audio=False, embeddedContext=True, image=False
|
||||
),
|
||||
mcpCapabilities=McpCapabilities(http=False, sse=False),
|
||||
)
|
||||
assert response.result.agentInfo == Implementation(
|
||||
name="@mistralai/mistral-vibe", title="Mistral Vibe", version="0.1.0"
|
||||
)
|
||||
vibe_setup_method = next(
|
||||
(
|
||||
method
|
||||
for method in response.result.authMethods or []
|
||||
if method.id == "vibe-setup"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert vibe_setup_method is not None, "vibe-setup auth not found"
|
||||
assert vibe_setup_method.field_meta is not None
|
||||
assert "terminal-auth" in vibe_setup_method.field_meta.keys()
|
||||
|
||||
|
||||
class TestSessionManagement:
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_sessions_unique_ids(self) -> None:
|
||||
mock_env = get_mocking_env(mock_chunks=[mock_llm_chunk() for _ in range(3)])
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
await send_json_rpc(
|
||||
process,
|
||||
InitializeJsonRpcRequest(
|
||||
id=1, params=InitializeRequest(protocolVersion=1)
|
||||
),
|
||||
)
|
||||
await read_response_for_id(process, expected_id=1, timeout=5.0)
|
||||
|
||||
session_ids = []
|
||||
for i in range(3):
|
||||
await send_json_rpc(
|
||||
process,
|
||||
NewSessionJsonRpcRequest(
|
||||
id=i + 2,
|
||||
params=NewSessionRequest(
|
||||
cwd=str(PLAYGROUND_DIR), mcpServers=[]
|
||||
),
|
||||
),
|
||||
)
|
||||
text_response = await read_response_for_id(
|
||||
process, expected_id=i + 2, timeout=RESPONSE_TIMEOUT
|
||||
)
|
||||
assert text_response is not None
|
||||
response_json = json.loads(text_response)
|
||||
response = NewSessionJsonRpcResponse.model_validate(response_json)
|
||||
assert response.error is None, f"JSON-RPC error: {response.error}"
|
||||
assert response.result is not None, "No result in response"
|
||||
session_ids.append(response.result.sessionId)
|
||||
|
||||
assert len(set(session_ids)) == 3
|
||||
|
||||
|
||||
class TestSessionUpdates:
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_message_chunk_structure(self) -> None:
|
||||
mock_env = get_mocking_env([mock_llm_chunk(content="Hi") for _ in range(2)])
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
# Check stderr for error details if process failed
|
||||
if process.returncode is not None and process.stderr:
|
||||
stderr_data = await process.stderr.read()
|
||||
if stderr_data:
|
||||
# Log stderr for debugging test failures
|
||||
pass # Could add proper logging here if needed
|
||||
|
||||
session_id = await initialize_session(process)
|
||||
|
||||
await send_json_rpc(
|
||||
process,
|
||||
PromptJsonRpcRequest(
|
||||
id=3,
|
||||
params=PromptRequest(
|
||||
sessionId=session_id,
|
||||
prompt=[TextContentBlock(type="text", text="Just say hi")],
|
||||
),
|
||||
),
|
||||
)
|
||||
text_response = await read_response(process)
|
||||
assert text_response is not None
|
||||
response = UpdateJsonRpcNotification.model_validate(
|
||||
json.loads(text_response)
|
||||
)
|
||||
|
||||
assert response.params is not None
|
||||
assert response.params.update.sessionUpdate == "agent_message_chunk"
|
||||
assert response.params.update.content is not None
|
||||
assert response.params.update.content.type == "text"
|
||||
assert response.params.update.content.text is not None
|
||||
assert response.params.update.content.text == "Hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_update_structure(self) -> None:
|
||||
mock_env = get_mocking_env([
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="grep", arguments='{"pattern": "auth"}'
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="bash",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The files containing the pattern 'auth' are ...",
|
||||
finish_reason="stop",
|
||||
),
|
||||
])
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
session_id = await initialize_session(process)
|
||||
|
||||
await send_json_rpc(
|
||||
process,
|
||||
PromptJsonRpcRequest(
|
||||
id=3,
|
||||
params=PromptRequest(
|
||||
sessionId=session_id,
|
||||
prompt=[
|
||||
TextContentBlock(
|
||||
type="text",
|
||||
text="Show me files that are related to auth",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=10)
|
||||
assert len(text_responses) > 0
|
||||
responses = [
|
||||
UpdateJsonRpcNotification.model_validate(json.loads(r))
|
||||
for r in text_responses
|
||||
]
|
||||
|
||||
tool_call = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.params is not None
|
||||
and r.params.update.sessionUpdate == "tool_call"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert tool_call is not None
|
||||
assert tool_call.params is not None
|
||||
assert tool_call.params.update is not None
|
||||
|
||||
assert tool_call.params.update.sessionUpdate == "tool_call"
|
||||
assert tool_call.params.update.kind == "search"
|
||||
assert tool_call.params.update.title == "grep: 'auth'"
|
||||
assert (
|
||||
tool_call.params.update.rawInput
|
||||
== '{"pattern":"auth","path":".","max_matches":null,"use_default_ignore":true}'
|
||||
)
|
||||
|
||||
|
||||
async def start_session_with_request_permission(
|
||||
process: asyncio.subprocess.Process, prompt: str
|
||||
) -> RequestPermissionJsonRpcRequest:
|
||||
session_id = await initialize_session(process)
|
||||
await send_json_rpc(
|
||||
process,
|
||||
PromptJsonRpcRequest(
|
||||
id=3,
|
||||
params=PromptRequest(
|
||||
sessionId=session_id,
|
||||
prompt=[TextContentBlock(type="text", text=prompt)],
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(
|
||||
process, max_count=15, timeout_per_response=2.0
|
||||
)
|
||||
|
||||
responses = parse_conversation(text_responses)
|
||||
last_response = responses[-1]
|
||||
|
||||
assert isinstance(last_response, RequestPermissionJsonRpcRequest)
|
||||
assert last_response.params is not None
|
||||
assert len(last_response.params.options) == 3
|
||||
return last_response
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Disabled until we have a way to properly mock the fs and acp interactions"
|
||||
)
|
||||
class TestToolCallStructure:
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_request_permission_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="write_file",
|
||||
arguments='{"path":"test.txt","content":"hello, world!"'
|
||||
',"overwrite":true}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="write_file",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
session_id = await initialize_session(process)
|
||||
await send_json_rpc(
|
||||
process,
|
||||
PromptJsonRpcRequest(
|
||||
id=3,
|
||||
params=PromptRequest(
|
||||
sessionId=session_id,
|
||||
prompt=[
|
||||
TextContentBlock(
|
||||
type="text",
|
||||
text="Create a new file named test.txt "
|
||||
"with content 'hello, world!'",
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=3)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
# Look for tool call request permission updates
|
||||
permission_requests = [
|
||||
r for r in responses if isinstance(r, RequestPermissionJsonRpcRequest)
|
||||
]
|
||||
|
||||
assert len(permission_requests) > 0, (
|
||||
"No tool call permission requests found"
|
||||
)
|
||||
|
||||
first_request = permission_requests[0]
|
||||
assert first_request.params is not None
|
||||
assert first_request.params.toolCall is not None
|
||||
assert first_request.params.toolCall.toolCallId is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_update_approved_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="write_file",
|
||||
arguments='{"path":"test.txt","content":"hello, world!"'
|
||||
',"overwrite":true}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="write_file",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The file test.txt has been created", finish_reason="stop"
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
permission_request = await start_session_with_request_permission(
|
||||
process, "Create a file named test.txt"
|
||||
)
|
||||
assert permission_request.params is not None
|
||||
selected_option_id = ToolOption.ALLOW_ONCE
|
||||
await send_json_rpc(
|
||||
process,
|
||||
RequestPermissionJsonRpcResponse(
|
||||
id=permission_request.id,
|
||||
result=RequestPermissionResponse(
|
||||
outcome=AllowedOutcome(
|
||||
outcome="selected", optionId=selected_option_id
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=7)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
approved_tool_call = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.method == "session/update"
|
||||
and r.params is not None
|
||||
and r.params.update is not None
|
||||
and r.params.update.sessionUpdate == "tool_call_update"
|
||||
and r.params.update.toolCallId
|
||||
== (permission_request.params.toolCall.toolCallId)
|
||||
and r.params.update.status == "completed"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert approved_tool_call is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_update_rejected_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="write_file",
|
||||
arguments='{"path":"test.txt","content":"hello, world!"'
|
||||
',"overwrite":false}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="write_file",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The file test.txt has not been created, "
|
||||
"because you rejected the permission request",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
permission_request = await start_session_with_request_permission(
|
||||
process, "Create a file named test.txt"
|
||||
)
|
||||
assert permission_request.params is not None
|
||||
selected_option_id = ToolOption.REJECT_ONCE
|
||||
|
||||
await send_json_rpc(
|
||||
process,
|
||||
RequestPermissionJsonRpcResponse(
|
||||
id=permission_request.id,
|
||||
result=RequestPermissionResponse(
|
||||
outcome=AllowedOutcome(
|
||||
outcome="selected", optionId=selected_option_id
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=5)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
rejected_tool_call = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.method == "session/update"
|
||||
and r.params is not None
|
||||
and r.params.update.sessionUpdate == "tool_call_update"
|
||||
and r.params.update.toolCallId
|
||||
== (permission_request.params.toolCall.toolCallId)
|
||||
and r.params.update.status == "failed"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert rejected_tool_call is not None
|
||||
|
||||
@pytest.mark.skip(reason="Long running tool call updates are not implemented yet")
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_in_progress_update_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="bash",
|
||||
arguments='{"command":"sleep 3","timeout":null}',
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
],
|
||||
name="bash",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The command sleep 3 has been run", finish_reason="stop"
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
session_id = await initialize_session(process)
|
||||
await send_json_rpc(
|
||||
process,
|
||||
PromptJsonRpcRequest(
|
||||
id=3,
|
||||
params=PromptRequest(
|
||||
sessionId=session_id,
|
||||
prompt=[
|
||||
TextContentBlock(
|
||||
type="text", text="Run sleep 3 in the current directory"
|
||||
)
|
||||
],
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=4)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
# Look for tool call in progress updates
|
||||
in_progress_calls = [
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.params is not None
|
||||
and r.params.update.sessionUpdate == "tool_call_update"
|
||||
and r.params.update.status == "in_progress"
|
||||
]
|
||||
|
||||
assert len(in_progress_calls) > 0, (
|
||||
"No tool call in progress updates found for a long running command"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_result_update_failure_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="write_file",
|
||||
arguments='{"path":"/test.txt","content":"hello, world!"'
|
||||
',"overwrite":true}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="write_file",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The file /test.txt has not been created "
|
||||
"because it's outside the project directory",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
permission_request = await start_session_with_request_permission(
|
||||
process, "Create a file named /test.txt"
|
||||
)
|
||||
assert permission_request.params is not None
|
||||
selected_option_id = ToolOption.ALLOW_ONCE
|
||||
await send_json_rpc(
|
||||
process,
|
||||
RequestPermissionJsonRpcResponse(
|
||||
id=permission_request.id,
|
||||
result=RequestPermissionResponse(
|
||||
outcome=AllowedOutcome(
|
||||
outcome="selected", optionId=selected_option_id
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=7)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
# Look for tool call result failure updates
|
||||
failure_result = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.params is not None
|
||||
and r.params.update.sessionUpdate == "tool_call_update"
|
||||
and r.params.update.status == "failed"
|
||||
and r.params.update.rawOutput is not None
|
||||
and r.params.update.toolCallId is not None
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
assert failure_result is not None
|
||||
|
||||
|
||||
class TestCancellationStructure:
|
||||
@pytest.mark.skip(
|
||||
reason="Proper cancellation is not implemented yet, we still need to return "
|
||||
"the right end_turn and be able to cancel at any point in time "
|
||||
"(and not only at tool call time)"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_update_cancelled_structure(self) -> None:
|
||||
custom_results = [
|
||||
mock_llm_chunk(content="Hey"),
|
||||
mock_llm_chunk(
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
function=FunctionCall(
|
||||
name="write_file",
|
||||
arguments='{"path":"test.txt","content":"hello, world!"'
|
||||
',"overwrite":false}',
|
||||
),
|
||||
type="function",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
name="write_file",
|
||||
finish_reason="tool_calls",
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="The file test.txt has not been created, "
|
||||
"because you cancelled the permission request",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]
|
||||
mock_env = get_mocking_env(custom_results)
|
||||
async for process in get_acp_agent_process(mock_env=mock_env):
|
||||
permission_request = await start_session_with_request_permission(
|
||||
process, "Create a file named test.txt"
|
||||
)
|
||||
assert permission_request.params is not None
|
||||
|
||||
await send_json_rpc(
|
||||
process,
|
||||
RequestPermissionJsonRpcResponse(
|
||||
id=permission_request.id,
|
||||
result=RequestPermissionResponse(
|
||||
outcome=DeniedOutcome(outcome="cancelled")
|
||||
),
|
||||
),
|
||||
)
|
||||
text_responses = await read_multiple_responses(process, max_count=5)
|
||||
responses = parse_conversation(text_responses)
|
||||
|
||||
assert len(responses) == 2, (
|
||||
"There should be only 2 responses: "
|
||||
"the tool call update and the prompt end turn"
|
||||
)
|
||||
|
||||
cancelled_tool_call = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, UpdateJsonRpcNotification)
|
||||
and r.method == "session/update"
|
||||
and r.params is not None
|
||||
and r.params.update.sessionUpdate == "tool_call_update"
|
||||
and r.params.update.toolCallId
|
||||
== (permission_request.params.toolCall.toolCallId)
|
||||
and r.params.update.status == "failed"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert cancelled_tool_call is not None
|
||||
|
||||
cancelled_prompt_response = next(
|
||||
(
|
||||
r
|
||||
for r in responses
|
||||
if isinstance(r, PromptJsonRpcResponse)
|
||||
and r.result is not None
|
||||
and r.result.stopReason == "cancelled"
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert cancelled_prompt_response is not None
|
||||
536
tests/acp/test_bash.py
Normal file
536
tests/acp/test_bash.py
Normal file
@@ -0,0 +1,536 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from acp.schema import TerminalOutputResponse, WaitForTerminalExitResponse
|
||||
import pytest
|
||||
|
||||
from vibe.acp.tools.builtins.bash import AcpBashState, Bash
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.bash import BashArgs, BashResult, BashToolConfig
|
||||
|
||||
|
||||
class MockTerminalHandle:
|
||||
def __init__(
|
||||
self,
|
||||
terminal_id: str = "test_terminal_123",
|
||||
exit_code: int | None = 0,
|
||||
output: str = "test output",
|
||||
wait_delay: float = 0.01,
|
||||
) -> None:
|
||||
self.id = terminal_id
|
||||
self._exit_code = exit_code
|
||||
self._output = output
|
||||
self._wait_delay = wait_delay
|
||||
self._killed = False
|
||||
|
||||
async def wait_for_exit(self) -> WaitForTerminalExitResponse:
|
||||
await asyncio.sleep(self._wait_delay)
|
||||
return WaitForTerminalExitResponse(exitCode=self._exit_code)
|
||||
|
||||
async def current_output(self) -> TerminalOutputResponse:
|
||||
return TerminalOutputResponse(output=self._output, truncated=False)
|
||||
|
||||
async def kill(self) -> None:
|
||||
self._killed = True
|
||||
|
||||
async def release(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockConnection:
|
||||
def __init__(self, terminal_handle: MockTerminalHandle | None = None) -> None:
|
||||
self._terminal_handle = terminal_handle or MockTerminalHandle()
|
||||
self._create_terminal_called = False
|
||||
self._session_update_called = False
|
||||
self._create_terminal_error: Exception | None = None
|
||||
self._last_create_request = None
|
||||
|
||||
async def createTerminal(self, request) -> MockTerminalHandle:
|
||||
self._create_terminal_called = True
|
||||
self._last_create_request = request
|
||||
if self._create_terminal_error:
|
||||
raise self._create_terminal_error
|
||||
return self._terminal_handle
|
||||
|
||||
async def sessionUpdate(self, notification) -> None:
|
||||
self._session_update_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection() -> MockConnection:
|
||||
return MockConnection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_bash_tool(mock_connection: MockConnection) -> Bash:
|
||||
config = BashToolConfig()
|
||||
# Use model_construct to bypass Pydantic validation for testing
|
||||
state = AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session_123",
|
||||
tool_call_id="test_tool_call_456",
|
||||
)
|
||||
return Bash(config=config, state=state)
|
||||
|
||||
|
||||
class TestAcpBashBasic:
|
||||
def test_get_name(self) -> None:
|
||||
assert Bash.get_name() == "bash"
|
||||
|
||||
def test_get_summary_simple_command(self) -> None:
|
||||
args = BashArgs(command="ls")
|
||||
display = Bash.get_summary(args)
|
||||
assert display == "ls"
|
||||
|
||||
def test_get_summary_with_timeout(self) -> None:
|
||||
args = BashArgs(command="ls", timeout=10)
|
||||
display = Bash.get_summary(args)
|
||||
assert display == "ls (timeout 10s)"
|
||||
|
||||
def test_parse_command_simple(self) -> None:
|
||||
tool = Bash(config=BashToolConfig(), state=AcpBashState())
|
||||
env, command, args = tool._parse_command("ls")
|
||||
assert env == []
|
||||
assert command == "ls"
|
||||
assert args == []
|
||||
|
||||
def test_parse_command_with_args(self) -> None:
|
||||
tool = Bash(config=BashToolConfig(), state=AcpBashState())
|
||||
env, command, args = tool._parse_command("ls -la src")
|
||||
assert env == []
|
||||
assert command == "ls"
|
||||
assert args == ["-la", "src"]
|
||||
|
||||
def test_parse_command_with_env(self) -> None:
|
||||
tool = Bash(config=BashToolConfig(), state=AcpBashState())
|
||||
env, command, args = tool._parse_command("NODE_ENV=test DEBUG=1 npm test")
|
||||
assert len(env) == 2
|
||||
assert env[0].name == "NODE_ENV"
|
||||
assert env[0].value == "test"
|
||||
assert env[1].name == "DEBUG"
|
||||
assert env[1].value == "1"
|
||||
assert command == "npm"
|
||||
assert args == ["test"]
|
||||
|
||||
def test_parse_command_with_env_value_contains_equals(self) -> None:
|
||||
tool = Bash(config=BashToolConfig(), state=AcpBashState())
|
||||
env, command, args = tool._parse_command(
|
||||
"PATH=/usr/bin:/usr/local/bin echo hello"
|
||||
)
|
||||
assert len(env) == 1
|
||||
assert env[0].name == "PATH"
|
||||
assert env[0].value == "/usr/bin:/usr/local/bin"
|
||||
assert command == "echo"
|
||||
assert args == ["hello"]
|
||||
|
||||
|
||||
class TestAcpBashExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success(
|
||||
self, acp_bash_tool: Bash, mock_connection: MockConnection
|
||||
) -> None:
|
||||
from pathlib import Path
|
||||
|
||||
args = BashArgs(command="echo hello")
|
||||
result = await acp_bash_tool.run(args)
|
||||
|
||||
assert isinstance(result, BashResult)
|
||||
assert result.stdout == "test output"
|
||||
assert result.stderr == ""
|
||||
assert result.returncode == 0
|
||||
assert mock_connection._create_terminal_called
|
||||
|
||||
# Verify CreateTerminalRequest was created correctly
|
||||
request = mock_connection._last_create_request
|
||||
assert request is not None
|
||||
assert request.sessionId == "test_session_123"
|
||||
assert request.command == "echo"
|
||||
assert request.args == ["hello"]
|
||||
assert request.cwd == str(Path.cwd()) # effective_workdir defaults to cwd
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_creates_terminal_with_env_vars(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="NODE_ENV=test npm run build")
|
||||
await tool.run(args)
|
||||
|
||||
request = mock_connection._last_create_request
|
||||
assert request is not None
|
||||
assert len(request.env) == 1
|
||||
assert request.env[0].name == "NODE_ENV"
|
||||
assert request.env[0].value == "test"
|
||||
assert request.command == "npm"
|
||||
assert request.args == ["run", "build"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_nonzero_exit_code(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="custom_terminal", exit_code=1, output="error: command failed"
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test_command")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Command failed: 'test_command'\nReturn code: 1\nStdout: error: command failed"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_create_terminal_failure(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
mock_connection._create_terminal_error = RuntimeError("Connection failed")
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Failed to create terminal: RuntimeError('Connection failed')"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_connection(self) -> None:
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=None, session_id="test_session", tool_call_id="test_call"
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Connection not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_session_id(self) -> None:
|
||||
mock_connection = MockConnection()
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id=None,
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Session ID not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_none_exit_code(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="none_exit_terminal", exit_code=None, output="output"
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test_command")
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == "output"
|
||||
|
||||
|
||||
class TestAcpBashTimeout:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_timeout_raises_error_and_kills(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="timeout_terminal",
|
||||
output="partial output",
|
||||
wait_delay=20, # Longer than the 1 second timeout
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
# Use a config with different default timeout to verify args timeout overrides it
|
||||
tool = Bash(
|
||||
config=BashToolConfig(default_timeout=30),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="slow_command", timeout=1)
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == "Command timed out after 1s: 'slow_command'"
|
||||
assert custom_handle._killed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_timeout_handles_kill_failure(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="kill_failure_terminal",
|
||||
wait_delay=20, # Longer than the 1 second timeout
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
async def failing_kill() -> None:
|
||||
raise RuntimeError("Kill failed")
|
||||
|
||||
custom_handle.kill = failing_kill
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="slow_command", timeout=1)
|
||||
# Should still raise timeout error even if kill fails
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == "Command timed out after 1s: 'slow_command'"
|
||||
|
||||
|
||||
class TestAcpBashEmbedding:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_embedding(self, mock_connection: MockConnection) -> None:
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
await tool.run(args)
|
||||
|
||||
assert mock_connection._session_update_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_embedding_without_tool_call_id(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id=None,
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
await tool.run(args)
|
||||
|
||||
# Embedding should be skipped when tool_call_id is None
|
||||
assert not mock_connection._session_update_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_embedding_handles_exception(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
# Make sessionUpdate raise an exception
|
||||
async def failing_session_update(notification) -> None:
|
||||
raise RuntimeError("Session update failed")
|
||||
|
||||
mock_connection.sessionUpdate = failing_session_update
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
# Should not raise, embedding failure is silently ignored
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result is not None
|
||||
assert result.stdout == "test output"
|
||||
|
||||
|
||||
class TestAcpBashConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_uses_config_default_timeout(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="config_timeout_terminal",
|
||||
wait_delay=0.01, # Shorter than config timeout
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(default_timeout=30),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="fast", timeout=None)
|
||||
result = await tool.run(args)
|
||||
|
||||
# Should succeed with config timeout
|
||||
assert result.returncode == 0
|
||||
|
||||
|
||||
class TestAcpBashCleanup:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_releases_terminal_on_success(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(terminal_id="cleanup_terminal")
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
release_called = False
|
||||
|
||||
async def mock_release() -> None:
|
||||
nonlocal release_called
|
||||
release_called = True
|
||||
|
||||
custom_handle.release = mock_release
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
await tool.run(args)
|
||||
|
||||
assert release_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_releases_terminal_on_timeout(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
# The handle will wait 2 seconds, but timeout is 1 second,
|
||||
# so asyncio.wait_for() will raise TimeoutError
|
||||
custom_handle = MockTerminalHandle(
|
||||
terminal_id="timeout_cleanup_terminal",
|
||||
wait_delay=2.0, # Longer than the 1 second timeout
|
||||
)
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
release_called = False
|
||||
|
||||
async def mock_release() -> None:
|
||||
nonlocal release_called
|
||||
release_called = True
|
||||
|
||||
custom_handle.release = mock_release
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="slow", timeout=1)
|
||||
# Timeout raises an error, but terminal should still be released
|
||||
try:
|
||||
await tool.run(args)
|
||||
except ToolError:
|
||||
pass
|
||||
|
||||
assert release_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_handles_release_failure(
|
||||
self, mock_connection: MockConnection
|
||||
) -> None:
|
||||
custom_handle = MockTerminalHandle(terminal_id="release_failure_terminal")
|
||||
|
||||
async def failing_release() -> None:
|
||||
raise RuntimeError("Release failed")
|
||||
|
||||
custom_handle.release = failing_release
|
||||
mock_connection._terminal_handle = custom_handle
|
||||
|
||||
tool = Bash(
|
||||
config=BashToolConfig(),
|
||||
state=AcpBashState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = BashArgs(command="test")
|
||||
# Should not raise, release failure is silently ignored
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result is not None
|
||||
assert result.stdout == "test output"
|
||||
184
tests/acp/test_content.py
Normal file
184
tests/acp/test_content.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from acp import AgentSideConnection, NewSessionRequest, PromptRequest
|
||||
from acp.schema import (
|
||||
EmbeddedResourceContentBlock,
|
||||
ResourceContentBlock,
|
||||
TextContentBlock,
|
||||
TextResourceContents,
|
||||
)
|
||||
import pytest
|
||||
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_connection import FakeAgentSideConnection
|
||||
from vibe.acp.acp_agent import VibeAcpAgent
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend() -> FakeBackend:
|
||||
backend = FakeBackend(
|
||||
results=[
|
||||
LLMChunk(
|
||||
message=LLMMessage(role=Role.assistant, content="Hi"),
|
||||
finish_reason="end_turn",
|
||||
usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
|
||||
class PatchedAgent(Agent):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs, backend=backend)
|
||||
|
||||
patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
|
||||
|
||||
vibe_acp_agent: VibeAcpAgent | None = None
|
||||
|
||||
def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
|
||||
nonlocal vibe_acp_agent
|
||||
vibe_acp_agent = VibeAcpAgent(connection)
|
||||
return vibe_acp_agent
|
||||
|
||||
FakeAgentSideConnection(_create_agent)
|
||||
return vibe_acp_agent # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class TestACPContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_content(
|
||||
self, acp_agent: VibeAcpAgent, backend: FakeBackend
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
prompt_request = PromptRequest(
|
||||
prompt=[TextContentBlock(type="text", text="Say hi")],
|
||||
sessionId=session_response.sessionId,
|
||||
)
|
||||
|
||||
response = await acp_agent.prompt(params=prompt_request)
|
||||
|
||||
assert response.stopReason == "end_turn"
|
||||
user_message = next(
|
||||
(msg for msg in backend._requests_messages[0] if msg.role == Role.user),
|
||||
None,
|
||||
)
|
||||
assert user_message is not None, "User message not found in backend requests"
|
||||
assert user_message.content == "Say hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_content(
|
||||
self, acp_agent: VibeAcpAgent, backend: FakeBackend
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
prompt_request = PromptRequest(
|
||||
prompt=[
|
||||
TextContentBlock(type="text", text="What does this file do?"),
|
||||
EmbeddedResourceContentBlock(
|
||||
type="resource",
|
||||
resource=TextResourceContents(
|
||||
uri="file:///home/my_file.py",
|
||||
text="def hello():\n print('Hello, world!')",
|
||||
mimeType="text/x-python",
|
||||
),
|
||||
),
|
||||
],
|
||||
sessionId=session_response.sessionId,
|
||||
)
|
||||
|
||||
response = await acp_agent.prompt(params=prompt_request)
|
||||
|
||||
assert response.stopReason == "end_turn"
|
||||
user_message = next(
|
||||
(msg for msg in backend._requests_messages[0] if msg.role == Role.user),
|
||||
None,
|
||||
)
|
||||
assert user_message is not None, "User message not found in backend requests"
|
||||
expected_content = (
|
||||
"What does this file do?"
|
||||
+ "\n\npath: file:///home/my_file.py"
|
||||
+ "\ncontent: def hello():\n print('Hello, world!')"
|
||||
)
|
||||
assert user_message.content == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_link_content(
|
||||
self, acp_agent: VibeAcpAgent, backend: FakeBackend
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
prompt_request = PromptRequest(
|
||||
prompt=[
|
||||
TextContentBlock(type="text", text="Analyze this resource"),
|
||||
ResourceContentBlock(
|
||||
type="resource_link",
|
||||
uri="file:///home/document.pdf",
|
||||
name="document.pdf",
|
||||
title="Important Document",
|
||||
description="A PDF document containing project specifications",
|
||||
mimeType="application/pdf",
|
||||
size=1024,
|
||||
),
|
||||
],
|
||||
sessionId=session_response.sessionId,
|
||||
)
|
||||
|
||||
response = await acp_agent.prompt(params=prompt_request)
|
||||
|
||||
assert response.stopReason == "end_turn"
|
||||
user_message = next(
|
||||
(msg for msg in backend._requests_messages[0] if msg.role == Role.user),
|
||||
None,
|
||||
)
|
||||
assert user_message is not None, "User message not found in backend requests"
|
||||
expected_content = (
|
||||
"Analyze this resource"
|
||||
+ "\n\nuri: file:///home/document.pdf"
|
||||
+ "\nname: document.pdf"
|
||||
+ "\ntitle: Important Document"
|
||||
+ "\ndescription: A PDF document containing project specifications"
|
||||
+ "\nmimeType: application/pdf"
|
||||
+ "\nsize: 1024"
|
||||
)
|
||||
assert user_message.content == expected_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resource_link_minimal(
|
||||
self, acp_agent: VibeAcpAgent, backend: FakeBackend
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
prompt_request = PromptRequest(
|
||||
prompt=[
|
||||
ResourceContentBlock(
|
||||
type="resource_link",
|
||||
uri="file:///home/minimal.txt",
|
||||
name="minimal.txt",
|
||||
)
|
||||
],
|
||||
sessionId=session_response.sessionId,
|
||||
)
|
||||
|
||||
response = await acp_agent.prompt(params=prompt_request)
|
||||
|
||||
assert response.stopReason == "end_turn"
|
||||
user_message = next(
|
||||
(msg for msg in backend._requests_messages[0] if msg.role == Role.user),
|
||||
None,
|
||||
)
|
||||
assert user_message is not None, "User message not found in backend requests"
|
||||
expected_content = "uri: file:///home/minimal.txt\nname: minimal.txt"
|
||||
assert user_message.content == expected_content
|
||||
161
tests/acp/test_multi_session.py
Normal file
161
tests/acp/test_multi_session.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from acp import (
|
||||
PROTOCOL_VERSION,
|
||||
InitializeRequest,
|
||||
NewSessionRequest,
|
||||
PromptRequest,
|
||||
RequestError,
|
||||
)
|
||||
from acp.schema import TextContentBlock
|
||||
import pytest
|
||||
from pytest import raises
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_connection import FakeAgentSideConnection
|
||||
from vibe.acp.acp_agent import VibeAcpAgent
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import ModelConfig, VibeConfig
|
||||
from vibe.core.types import Role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend() -> FakeBackend:
|
||||
backend = FakeBackend()
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
|
||||
config = VibeConfig(
|
||||
active_model="devstral-latest",
|
||||
models=[
|
||||
ModelConfig(
|
||||
name="devstral-latest", provider="mistral", alias="devstral-latest"
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
class PatchedAgent(Agent):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.backend = backend
|
||||
self.config = config
|
||||
|
||||
patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
|
||||
|
||||
vibe_acp_agent: VibeAcpAgent | None = None
|
||||
|
||||
def _create_agent(connection: Any) -> VibeAcpAgent:
|
||||
nonlocal vibe_acp_agent
|
||||
vibe_acp_agent = VibeAcpAgent(connection)
|
||||
return vibe_acp_agent
|
||||
|
||||
FakeAgentSideConnection(_create_agent)
|
||||
return vibe_acp_agent # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class TestMultiSessionCore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_sessions_use_different_agents(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
|
||||
session1_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session1 = acp_agent.sessions[session1_response.sessionId]
|
||||
session2_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session2 = acp_agent.sessions[session2_response.sessionId]
|
||||
|
||||
assert session1.id != session2.id
|
||||
# Each agent should be independent
|
||||
assert session1.agent is not session2.agent
|
||||
assert id(session1.agent) != id(session2.agent)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_on_nonexistent_session(self, acp_agent: VibeAcpAgent) -> None:
|
||||
await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
|
||||
await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
|
||||
fake_session_id = "fake-session-id-" + str(uuid4())
|
||||
|
||||
with raises(RequestError) as exc_info:
|
||||
await acp_agent.prompt(
|
||||
PromptRequest(
|
||||
sessionId=fake_session_id,
|
||||
prompt=[TextContentBlock(type="text", text="Hello, world!")],
|
||||
)
|
||||
)
|
||||
|
||||
assert isinstance(exc_info.value, RequestError)
|
||||
assert str(exc_info.value) == "Invalid params"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simultaneous_message_processing(
|
||||
self, acp_agent: VibeAcpAgent, backend: FakeBackend
|
||||
) -> None:
|
||||
await acp_agent.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION))
|
||||
session1_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session1 = acp_agent.sessions[session1_response.sessionId]
|
||||
session2_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session2 = acp_agent.sessions[session2_response.sessionId]
|
||||
|
||||
backend._chunks = [
|
||||
mock_llm_chunk(content="Response 1", finish_reason="stop"),
|
||||
mock_llm_chunk(content="Response 2", finish_reason="stop"),
|
||||
]
|
||||
|
||||
async def run_session1():
|
||||
await acp_agent.prompt(
|
||||
PromptRequest(
|
||||
sessionId=session1.id,
|
||||
prompt=[TextContentBlock(type="text", text="Prompt for session 1")],
|
||||
)
|
||||
)
|
||||
|
||||
async def run_session2():
|
||||
await acp_agent.prompt(
|
||||
PromptRequest(
|
||||
sessionId=session2.id,
|
||||
prompt=[TextContentBlock(type="text", text="Prompt for session 2")],
|
||||
)
|
||||
)
|
||||
|
||||
await asyncio.gather(run_session1(), run_session2())
|
||||
|
||||
user_message1 = next(
|
||||
(msg for msg in session1.agent.messages if msg.role == Role.user), None
|
||||
)
|
||||
assert user_message1 is not None
|
||||
assert user_message1.content == "Prompt for session 1"
|
||||
assistant_message1 = next(
|
||||
(msg for msg in session1.agent.messages if msg.role == Role.assistant), None
|
||||
)
|
||||
assert assistant_message1 is not None
|
||||
assert assistant_message1.content == "Response 1"
|
||||
user_message2 = next(
|
||||
(msg for msg in session2.agent.messages if msg.role == Role.user), None
|
||||
)
|
||||
assert user_message2 is not None
|
||||
assert user_message2.content == "Prompt for session 2"
|
||||
assistant_message2 = next(
|
||||
(msg for msg in session2.agent.messages if msg.role == Role.assistant), None
|
||||
)
|
||||
assert assistant_message2 is not None
|
||||
assert assistant_message2.content == "Response 2"
|
||||
140
tests/acp/test_new_session.py
Normal file
140
tests/acp/test_new_session.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from acp import AgentSideConnection, NewSessionRequest, SetSessionModelRequest
|
||||
import pytest
|
||||
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_connection import FakeAgentSideConnection
|
||||
from vibe.acp.acp_agent import VibeAcpAgent
|
||||
from vibe.acp.utils import VibeSessionMode
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import ModelConfig, VibeConfig
|
||||
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend() -> FakeBackend:
|
||||
backend = FakeBackend(
|
||||
results=[
|
||||
LLMChunk(
|
||||
message=LLMMessage(role=Role.assistant, content="Hi"),
|
||||
finish_reason="end_turn",
|
||||
usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
|
||||
config = VibeConfig(
|
||||
active_model="devstral-latest",
|
||||
models=[
|
||||
ModelConfig(
|
||||
name="devstral-latest", provider="mistral", alias="devstral-latest"
|
||||
),
|
||||
ModelConfig(
|
||||
name="devstral-small", provider="mistral", alias="devstral-small"
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
class PatchedAgent(Agent):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **{**kwargs, "backend": backend})
|
||||
self.config = config
|
||||
|
||||
patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
|
||||
|
||||
vibe_acp_agent: VibeAcpAgent | None = None
|
||||
|
||||
def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
|
||||
nonlocal vibe_acp_agent
|
||||
vibe_acp_agent = VibeAcpAgent(connection)
|
||||
return vibe_acp_agent
|
||||
|
||||
FakeAgentSideConnection(_create_agent)
|
||||
return vibe_acp_agent # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class TestACPNewSession:
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_response_structure(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
|
||||
assert session_response.sessionId is not None
|
||||
acp_session = next(
|
||||
(
|
||||
s
|
||||
for s in acp_agent.sessions.values()
|
||||
if s.id == session_response.sessionId
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert acp_session is not None
|
||||
assert (
|
||||
acp_session.agent.interaction_logger.session_id
|
||||
== session_response.sessionId
|
||||
)
|
||||
|
||||
assert session_response.sessionId == acp_session.agent.session_id
|
||||
|
||||
assert session_response.models is not None
|
||||
assert session_response.models.currentModelId is not None
|
||||
assert session_response.models.availableModels is not None
|
||||
assert len(session_response.models.availableModels) == 2
|
||||
|
||||
assert session_response.models.currentModelId == "devstral-latest"
|
||||
assert session_response.models.availableModels[0].modelId == "devstral-latest"
|
||||
assert session_response.models.availableModels[0].name == "devstral-latest"
|
||||
assert session_response.models.availableModels[1].modelId == "devstral-small"
|
||||
assert session_response.models.availableModels[1].name == "devstral-small"
|
||||
|
||||
assert session_response.modes is not None
|
||||
assert session_response.modes.currentModeId is not None
|
||||
assert session_response.modes.availableModes is not None
|
||||
assert len(session_response.modes.availableModes) == 2
|
||||
|
||||
assert session_response.modes.currentModeId == VibeSessionMode.APPROVAL_REQUIRED
|
||||
assert (
|
||||
session_response.modes.availableModes[0].id
|
||||
== VibeSessionMode.APPROVAL_REQUIRED
|
||||
)
|
||||
assert session_response.modes.availableModes[0].name == "Approval Required"
|
||||
assert (
|
||||
session_response.modes.availableModes[1].id == VibeSessionMode.AUTO_APPROVE
|
||||
)
|
||||
assert session_response.modes.availableModes[1].name == "Auto Approve"
|
||||
|
||||
@pytest.mark.skip(reason="TODO: Fix this test")
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_session_preserves_model_after_set_model(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
|
||||
assert session_response.models is not None
|
||||
assert session_response.models.currentModelId == "devstral-latest"
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
|
||||
assert session_response.models is not None
|
||||
assert session_response.models.currentModelId == "devstral-small"
|
||||
240
tests/acp/test_read_file.py
Normal file
240
tests/acp/test_read_file.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import ReadTextFileRequest, ReadTextFileResponse
|
||||
import pytest
|
||||
|
||||
from vibe.acp.tools.builtins.read_file import AcpReadFileState, ReadFile
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.read_file import (
|
||||
ReadFileArgs,
|
||||
ReadFileResult,
|
||||
ReadFileToolConfig,
|
||||
)
|
||||
|
||||
|
||||
class MockConnection:
|
||||
def __init__(
|
||||
self,
|
||||
file_content: str = "line 1\nline 2\nline 3",
|
||||
read_error: Exception | None = None,
|
||||
) -> None:
|
||||
self._file_content = file_content
|
||||
self._read_error = read_error
|
||||
self._read_text_file_called = False
|
||||
self._session_update_called = False
|
||||
self._last_read_request: ReadTextFileRequest | None = None
|
||||
|
||||
async def readTextFile(self, request: ReadTextFileRequest) -> ReadTextFileResponse:
|
||||
self._read_text_file_called = True
|
||||
self._last_read_request = request
|
||||
|
||||
if self._read_error:
|
||||
raise self._read_error
|
||||
|
||||
content = self._file_content
|
||||
if request.line is not None or request.limit is not None:
|
||||
lines = content.splitlines(keepends=True)
|
||||
start_line = (request.line or 1) - 1 # Convert to 0-indexed
|
||||
end_line = (
|
||||
start_line + request.limit if request.limit is not None else len(lines)
|
||||
)
|
||||
lines = lines[start_line:end_line]
|
||||
content = "".join(lines)
|
||||
|
||||
return ReadTextFileResponse(content=content)
|
||||
|
||||
async def sessionUpdate(self, notification) -> None:
|
||||
self._session_update_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection() -> MockConnection:
|
||||
return MockConnection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_read_file_tool(mock_connection: MockConnection, tmp_path: Path) -> ReadFile:
|
||||
config = ReadFileToolConfig(workdir=tmp_path)
|
||||
state = AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session_123",
|
||||
tool_call_id="test_tool_call_456",
|
||||
)
|
||||
return ReadFile(config=config, state=state)
|
||||
|
||||
|
||||
class TestAcpReadFileBasic:
|
||||
def test_get_name(self) -> None:
|
||||
assert ReadFile.get_name() == "read_file"
|
||||
|
||||
|
||||
class TestAcpReadFileExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success(
|
||||
self,
|
||||
acp_read_file_tool: ReadFile,
|
||||
mock_connection: MockConnection,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.touch()
|
||||
args = ReadFileArgs(path=str(test_file))
|
||||
result = await acp_read_file_tool.run(args)
|
||||
|
||||
assert isinstance(result, ReadFileResult)
|
||||
assert result.path == str(test_file)
|
||||
assert result.content == "line 1\nline 2\nline 3"
|
||||
assert result.lines_read == 3
|
||||
assert mock_connection._read_text_file_called
|
||||
assert mock_connection._session_update_called
|
||||
|
||||
# Verify ReadTextFileRequest was created correctly
|
||||
request = mock_connection._last_read_request
|
||||
assert request is not None
|
||||
assert request.sessionId == "test_session_123"
|
||||
assert request.path == str(test_file)
|
||||
assert request.line is None # offset=0 means no line specified
|
||||
assert request.limit is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_offset(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.touch()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file), offset=1)
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result.lines_read == 2
|
||||
assert result.content == "line 2\nline 3"
|
||||
|
||||
request = mock_connection._last_read_request
|
||||
assert request is not None
|
||||
assert request.line == 2 # offset=1 means line 2 (1-indexed)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_limit(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.touch()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file), limit=2)
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result.lines_read == 2
|
||||
assert result.content == "line 1\nline 2\n"
|
||||
|
||||
request = mock_connection._last_read_request
|
||||
assert request is not None
|
||||
assert request.limit == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_offset_and_limit(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.touch()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file), offset=1, limit=1)
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result.lines_read == 1
|
||||
assert result.content == "line 2\n"
|
||||
|
||||
request = mock_connection._last_read_request
|
||||
assert request is not None
|
||||
assert request.line == 2
|
||||
assert request.limit == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_read_error(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
mock_connection._read_error = RuntimeError("File not found")
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file))
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == f"Error reading {test_file}: File not found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_connection(self, tmp_path: Path) -> None:
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=None, session_id="test_session", tool_call_id="test_call"
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file))
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Connection not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_session_id(self, tmp_path: Path) -> None:
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
mock_connection = MockConnection()
|
||||
tool = ReadFile(
|
||||
config=ReadFileToolConfig(workdir=tmp_path),
|
||||
state=AcpReadFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id=None,
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = ReadFileArgs(path=str(test_file))
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Session ID not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
339
tests/acp/test_search_replace.py
Normal file
339
tests/acp/test_search_replace.py
Normal file
@@ -0,0 +1,339 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import ReadTextFileRequest, ReadTextFileResponse, WriteTextFileRequest
|
||||
import pytest
|
||||
|
||||
from vibe.acp.tools.builtins.search_replace import AcpSearchReplaceState, SearchReplace
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.search_replace import (
|
||||
SearchReplaceArgs,
|
||||
SearchReplaceConfig,
|
||||
SearchReplaceResult,
|
||||
)
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
|
||||
|
||||
class MockConnection:
|
||||
def __init__(
|
||||
self,
|
||||
file_content: str = "original line 1\noriginal line 2\noriginal line 3",
|
||||
read_error: Exception | None = None,
|
||||
write_error: Exception | None = None,
|
||||
) -> None:
|
||||
self._file_content = file_content
|
||||
self._read_error = read_error
|
||||
self._write_error = write_error
|
||||
self._read_text_file_called = False
|
||||
self._write_text_file_called = False
|
||||
self._session_update_called = False
|
||||
self._last_read_request: ReadTextFileRequest | None = None
|
||||
self._last_write_request: WriteTextFileRequest | None = None
|
||||
self._write_calls: list[WriteTextFileRequest] = []
|
||||
|
||||
async def readTextFile(self, request: ReadTextFileRequest) -> ReadTextFileResponse:
|
||||
self._read_text_file_called = True
|
||||
self._last_read_request = request
|
||||
|
||||
if self._read_error:
|
||||
raise self._read_error
|
||||
|
||||
return ReadTextFileResponse(content=self._file_content)
|
||||
|
||||
async def writeTextFile(self, request: WriteTextFileRequest) -> None:
|
||||
self._write_text_file_called = True
|
||||
self._last_write_request = request
|
||||
self._write_calls.append(request)
|
||||
|
||||
if self._write_error:
|
||||
raise self._write_error
|
||||
|
||||
async def sessionUpdate(self, notification) -> None:
|
||||
self._session_update_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection() -> MockConnection:
|
||||
return MockConnection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_search_replace_tool(
|
||||
mock_connection: MockConnection, tmp_path: Path
|
||||
) -> SearchReplace:
|
||||
config = SearchReplaceConfig(workdir=tmp_path)
|
||||
state = AcpSearchReplaceState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session_123",
|
||||
tool_call_id="test_tool_call_456",
|
||||
)
|
||||
return SearchReplace(config=config, state=state)
|
||||
|
||||
|
||||
class TestAcpSearchReplaceBasic:
|
||||
def test_get_name(self) -> None:
|
||||
assert SearchReplace.get_name() == "search_replace"
|
||||
|
||||
|
||||
class TestAcpSearchReplaceExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success(
|
||||
self,
|
||||
acp_search_replace_tool: SearchReplace,
|
||||
mock_connection: MockConnection,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.write_text("original line 1\noriginal line 2\noriginal line 3")
|
||||
search_replace_content = (
|
||||
"<<<<<<< SEARCH\noriginal line 2\n=======\nmodified line 2\n>>>>>>> REPLACE"
|
||||
)
|
||||
args = SearchReplaceArgs(
|
||||
file_path=str(test_file), content=search_replace_content
|
||||
)
|
||||
result = await acp_search_replace_tool.run(args)
|
||||
|
||||
assert isinstance(result, SearchReplaceResult)
|
||||
assert result.file == str(test_file)
|
||||
assert result.blocks_applied == 1
|
||||
assert mock_connection._read_text_file_called
|
||||
assert mock_connection._write_text_file_called
|
||||
assert mock_connection._session_update_called
|
||||
|
||||
# Verify ReadTextFileRequest was created correctly
|
||||
read_request = mock_connection._last_read_request
|
||||
assert read_request is not None
|
||||
assert read_request.sessionId == "test_session_123"
|
||||
assert read_request.path == str(test_file)
|
||||
|
||||
# Verify WriteTextFileRequest was created correctly
|
||||
write_request = mock_connection._last_write_request
|
||||
assert write_request is not None
|
||||
assert write_request.sessionId == "test_session_123"
|
||||
assert write_request.path == str(test_file)
|
||||
assert (
|
||||
write_request.content == "original line 1\nmodified line 2\noriginal line 3"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_backup(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
config = SearchReplaceConfig(create_backup=True, workdir=tmp_path)
|
||||
tool = SearchReplace(
|
||||
config=config,
|
||||
state=AcpSearchReplaceState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
test_file.write_text("original line 1\noriginal line 2\noriginal line 3")
|
||||
search_replace_content = (
|
||||
"<<<<<<< SEARCH\noriginal line 1\n=======\nmodified line 1\n>>>>>>> REPLACE"
|
||||
)
|
||||
args = SearchReplaceArgs(
|
||||
file_path=str(test_file), content=search_replace_content
|
||||
)
|
||||
result = await tool.run(args)
|
||||
|
||||
assert result.blocks_applied == 1
|
||||
# Should have written the main file and the backup
|
||||
assert len(mock_connection._write_calls) >= 1
|
||||
# Check if backup was written (it should be written to .bak file)
|
||||
assert sum(w.path.endswith(".bak") for w in mock_connection._write_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_read_error(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
mock_connection._read_error = RuntimeError("File not found")
|
||||
|
||||
tool = SearchReplace(
|
||||
config=SearchReplaceConfig(workdir=tmp_path),
|
||||
state=AcpSearchReplaceState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
|
||||
args = SearchReplaceArgs(
|
||||
file_path=str(test_file), content=search_replace_content
|
||||
)
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== f"Unexpected error reading {test_file}: File not found"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_write_error(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
mock_connection._write_error = RuntimeError("Permission denied")
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
mock_connection._file_content = "old" # Update mock to return correct content
|
||||
|
||||
tool = SearchReplace(
|
||||
config=SearchReplaceConfig(workdir=tmp_path),
|
||||
state=AcpSearchReplaceState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
|
||||
args = SearchReplaceArgs(
|
||||
file_path=str(test_file), content=search_replace_content
|
||||
)
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == f"Error writing {test_file}: Permission denied"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"connection,session_id,expected_error",
|
||||
[
|
||||
(
|
||||
None,
|
||||
"test_session",
|
||||
"Connection not available in tool state. This tool can only be used within an ACP session.",
|
||||
),
|
||||
(
|
||||
MockConnection(),
|
||||
None,
|
||||
"Session ID not available in tool state. This tool can only be used within an ACP session.",
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_run_without_required_state(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
connection: MockConnection | None,
|
||||
session_id: str | None,
|
||||
expected_error: str,
|
||||
) -> None:
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.touch()
|
||||
tool = SearchReplace(
|
||||
config=SearchReplaceConfig(workdir=tmp_path),
|
||||
state=AcpSearchReplaceState.model_construct(
|
||||
connection=connection, # type: ignore[arg-type]
|
||||
session_id=session_id,
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
search_replace_content = "<<<<<<< SEARCH\nold\n=======\nnew\n>>>>>>> REPLACE"
|
||||
args = SearchReplaceArgs(
|
||||
file_path=str(test_file), content=search_replace_content
|
||||
)
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
|
||||
class TestAcpSearchReplaceSessionUpdates:
|
||||
def test_tool_call_session_update(self) -> None:
|
||||
search_replace_content = (
|
||||
"<<<<<<< SEARCH\nold text\n=======\nnew text\n>>>>>>> REPLACE"
|
||||
)
|
||||
event = ToolCallEvent(
|
||||
tool_name="search_replace",
|
||||
tool_call_id="test_call_123",
|
||||
args=SearchReplaceArgs(
|
||||
file_path="/tmp/test.txt", content=search_replace_content
|
||||
),
|
||||
tool_class=SearchReplace,
|
||||
)
|
||||
|
||||
update = SearchReplace.tool_call_session_update(event)
|
||||
assert update is not None
|
||||
assert update.sessionUpdate == "tool_call"
|
||||
assert update.toolCallId == "test_call_123"
|
||||
assert update.kind == "edit"
|
||||
assert update.title is not None
|
||||
assert update.content is not None
|
||||
assert len(update.content) == 1
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "/tmp/test.txt"
|
||||
assert update.content[0].oldText == "old text"
|
||||
assert update.content[0].newText == "new text"
|
||||
assert update.locations is not None
|
||||
assert len(update.locations) == 1
|
||||
assert update.locations[0].path == "/tmp/test.txt"
|
||||
|
||||
def test_tool_call_session_update_invalid_args(self) -> None:
|
||||
class InvalidArgs:
|
||||
pass
|
||||
|
||||
event = ToolCallEvent.model_construct(
|
||||
tool_name="search_replace",
|
||||
tool_call_id="test_call_123",
|
||||
args=InvalidArgs(), # type: ignore[arg-type]
|
||||
tool_class=SearchReplace,
|
||||
)
|
||||
|
||||
update = SearchReplace.tool_call_session_update(event)
|
||||
assert update is None
|
||||
|
||||
def test_tool_result_session_update(self) -> None:
|
||||
search_replace_content = (
|
||||
"<<<<<<< SEARCH\nold text\n=======\nnew text\n>>>>>>> REPLACE"
|
||||
)
|
||||
result = SearchReplaceResult(
|
||||
file="/tmp/test.txt",
|
||||
blocks_applied=1,
|
||||
lines_changed=1,
|
||||
content=search_replace_content,
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
event = ToolResultEvent(
|
||||
tool_name="search_replace",
|
||||
tool_call_id="test_call_123",
|
||||
result=result,
|
||||
tool_class=SearchReplace,
|
||||
)
|
||||
|
||||
update = SearchReplace.tool_result_session_update(event)
|
||||
assert update is not None
|
||||
assert update.sessionUpdate == "tool_call_update"
|
||||
assert update.toolCallId == "test_call_123"
|
||||
assert update.status == "completed"
|
||||
assert update.content is not None
|
||||
assert len(update.content) == 1
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "/tmp/test.txt"
|
||||
assert update.content[0].oldText == "old text"
|
||||
assert update.content[0].newText == "new text"
|
||||
assert update.locations is not None
|
||||
assert len(update.locations) == 1
|
||||
assert update.locations[0].path == "/tmp/test.txt"
|
||||
|
||||
def test_tool_result_session_update_invalid_result(self) -> None:
|
||||
class InvalidResult:
|
||||
pass
|
||||
|
||||
event = ToolResultEvent.model_construct(
|
||||
tool_name="search_replace",
|
||||
tool_call_id="test_call_123",
|
||||
result=InvalidResult(), # type: ignore[arg-type]
|
||||
tool_class=SearchReplace,
|
||||
)
|
||||
|
||||
update = SearchReplace.tool_result_session_update(event)
|
||||
assert update is None
|
||||
165
tests/acp/test_set_mode.py
Normal file
165
tests/acp/test_set_mode.py
Normal file
@@ -0,0 +1,165 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from acp import AgentSideConnection, NewSessionRequest, SetSessionModeRequest
|
||||
import pytest
|
||||
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_connection import FakeAgentSideConnection
|
||||
from vibe.acp.acp_agent import VibeAcpAgent
|
||||
from vibe.acp.utils import VibeSessionMode
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend() -> FakeBackend:
|
||||
backend = FakeBackend(
|
||||
results=[
|
||||
LLMChunk(
|
||||
message=LLMMessage(role=Role.assistant, content="Hi"),
|
||||
finish_reason="end_turn",
|
||||
usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
|
||||
class PatchedAgent(Agent):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs, backend=backend)
|
||||
|
||||
patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
|
||||
|
||||
vibe_acp_agent: VibeAcpAgent | None = None
|
||||
|
||||
def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
|
||||
nonlocal vibe_acp_agent
|
||||
vibe_acp_agent = VibeAcpAgent(connection)
|
||||
return vibe_acp_agent
|
||||
|
||||
FakeAgentSideConnection(_create_agent)
|
||||
return vibe_acp_agent # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class TestACPSetMode:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mode_to_approval_required(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
acp_session.agent.auto_approve = True
|
||||
acp_session.mode_id = VibeSessionMode.AUTO_APPROVE
|
||||
|
||||
response = await acp_agent.setSessionMode(
|
||||
SetSessionModeRequest(
|
||||
sessionId=session_id, modeId=VibeSessionMode.APPROVAL_REQUIRED
|
||||
)
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert acp_session.mode_id == VibeSessionMode.APPROVAL_REQUIRED
|
||||
assert acp_session.agent.auto_approve is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mode_to_AUTO_APPROVE(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
assert acp_session.mode_id == VibeSessionMode.APPROVAL_REQUIRED
|
||||
assert acp_session.agent.auto_approve is False
|
||||
|
||||
response = await acp_agent.setSessionMode(
|
||||
SetSessionModeRequest(
|
||||
sessionId=session_id, modeId=VibeSessionMode.AUTO_APPROVE
|
||||
)
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert acp_session.mode_id == VibeSessionMode.AUTO_APPROVE
|
||||
assert acp_session.agent.auto_approve is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mode_invalid_mode_returns_none(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
initial_mode_id = acp_session.mode_id
|
||||
initial_auto_approve = acp_session.agent.auto_approve
|
||||
|
||||
response = await acp_agent.setSessionMode(
|
||||
SetSessionModeRequest(sessionId=session_id, modeId="invalid-mode")
|
||||
)
|
||||
|
||||
assert response is None
|
||||
assert acp_session.mode_id == initial_mode_id
|
||||
assert acp_session.agent.auto_approve == initial_auto_approve
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mode_to_same_mode(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
initial_mode_id = VibeSessionMode.APPROVAL_REQUIRED
|
||||
assert acp_session.mode_id == initial_mode_id
|
||||
|
||||
response = await acp_agent.setSessionMode(
|
||||
SetSessionModeRequest(sessionId=session_id, modeId=initial_mode_id)
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert acp_session.mode_id == initial_mode_id
|
||||
assert acp_session.agent.auto_approve is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_mode_with_empty_string(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
initial_mode_id = acp_session.mode_id
|
||||
initial_auto_approve = acp_session.agent.auto_approve
|
||||
|
||||
response = await acp_agent.setSessionMode(
|
||||
SetSessionModeRequest(sessionId=session_id, modeId="")
|
||||
)
|
||||
|
||||
assert response is None
|
||||
assert acp_session.mode_id == initial_mode_id
|
||||
assert acp_session.agent.auto_approve == initial_auto_approve
|
||||
308
tests/acp/test_set_model.py
Normal file
308
tests/acp/test_set_model.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from acp import AgentSideConnection, NewSessionRequest, SetSessionModelRequest
|
||||
import pytest
|
||||
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_connection import FakeAgentSideConnection
|
||||
from vibe.acp.acp_agent import VibeAcpAgent
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import ModelConfig, VibeConfig
|
||||
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def backend() -> FakeBackend:
|
||||
backend = FakeBackend(
|
||||
results=[
|
||||
LLMChunk(
|
||||
message=LLMMessage(role=Role.assistant, content="Hi"),
|
||||
finish_reason="end_turn",
|
||||
usage=LLMUsage(prompt_tokens=1, completion_tokens=1),
|
||||
)
|
||||
]
|
||||
)
|
||||
return backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_agent(backend: FakeBackend) -> VibeAcpAgent:
|
||||
config = VibeConfig(
|
||||
active_model="devstral-latest",
|
||||
models=[
|
||||
ModelConfig(
|
||||
name="devstral-latest",
|
||||
provider="mistral",
|
||||
alias="devstral-latest",
|
||||
input_price=0.4,
|
||||
output_price=2.0,
|
||||
),
|
||||
ModelConfig(
|
||||
name="devstral-small",
|
||||
provider="mistral",
|
||||
alias="devstral-small",
|
||||
input_price=0.1,
|
||||
output_price=0.3,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
class PatchedAgent(Agent):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **{**kwargs, "backend": backend})
|
||||
self.config = config
|
||||
try:
|
||||
active_model = config.get_active_model()
|
||||
self.stats.input_price_per_million = active_model.input_price
|
||||
self.stats.output_price_per_million = active_model.output_price
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
patch("vibe.acp.acp_agent.VibeAgent", side_effect=PatchedAgent).start()
|
||||
|
||||
vibe_acp_agent: VibeAcpAgent | None = None
|
||||
|
||||
def _create_agent(connection: AgentSideConnection) -> VibeAcpAgent:
|
||||
nonlocal vibe_acp_agent
|
||||
vibe_acp_agent = VibeAcpAgent(connection)
|
||||
return vibe_acp_agent
|
||||
|
||||
FakeAgentSideConnection(_create_agent)
|
||||
return vibe_acp_agent # pyright: ignore[reportReturnType]
|
||||
|
||||
|
||||
class TestACPSetModel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_success(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
assert acp_session.agent.config.active_model == "devstral-latest"
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert acp_session.agent.config.active_model == "devstral-small"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_invalid_model_returns_none(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
initial_model = acp_session.agent.config.active_model
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="non-existent-model")
|
||||
)
|
||||
|
||||
assert response is None
|
||||
assert acp_session.agent.config.active_model == initial_model
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_to_same_model(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
initial_model = "devstral-latest"
|
||||
assert acp_session is not None
|
||||
assert acp_session.agent.config.active_model == initial_model
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId=initial_model)
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert acp_session.agent.config.active_model == initial_model
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_saves_to_config(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
|
||||
with patch("vibe.acp.acp_agent.VibeConfig.save_updates") as mock_save:
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
mock_save.assert_called_once_with({"active_model": "devstral-small"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_does_not_save_on_invalid_model(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
|
||||
with patch("vibe.acp.acp_agent.VibeConfig.save_updates") as mock_save:
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(
|
||||
sessionId=session_id, modelId="non-existent-model"
|
||||
)
|
||||
)
|
||||
|
||||
assert response is None
|
||||
mock_save.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_with_empty_string(self, acp_agent: VibeAcpAgent) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
initial_model = acp_session.agent.config.active_model
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="")
|
||||
)
|
||||
|
||||
assert response is None
|
||||
assert acp_session.agent.config.active_model == initial_model
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_updates_active_model(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
assert acp_session.agent.config.get_active_model().alias == "devstral-latest"
|
||||
|
||||
await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert acp_session.agent.config.get_active_model().alias == "devstral-small"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_calls_reload_with_initial_messages(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
with patch.object(
|
||||
acp_session.agent, "reload_with_initial_messages"
|
||||
) as mock_reload:
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
mock_reload.assert_called_once()
|
||||
call_args = mock_reload.call_args
|
||||
assert call_args.kwargs["config"] is not None
|
||||
assert call_args.kwargs["config"].active_model == "devstral-small"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_preserves_conversation_history(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
user_msg = LLMMessage(role=Role.user, content="Hello")
|
||||
assistant_msg = LLMMessage(role=Role.assistant, content="Hi there!")
|
||||
acp_session.agent.messages.append(user_msg)
|
||||
acp_session.agent.messages.append(assistant_msg)
|
||||
|
||||
assert len(acp_session.agent.messages) == 3
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert len(acp_session.agent.messages) == 3
|
||||
assert acp_session.agent.messages[0].role == Role.system
|
||||
assert acp_session.agent.messages[1].content == "Hello"
|
||||
assert acp_session.agent.messages[2].content == "Hi there!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_model_resets_stats_with_new_model_pricing(
|
||||
self, acp_agent: VibeAcpAgent
|
||||
) -> None:
|
||||
session_response = await acp_agent.newSession(
|
||||
NewSessionRequest(cwd=str(Path.cwd()), mcpServers=[])
|
||||
)
|
||||
session_id = session_response.sessionId
|
||||
acp_session = next(
|
||||
(s for s in acp_agent.sessions.values() if s.id == session_id), None
|
||||
)
|
||||
assert acp_session is not None
|
||||
|
||||
initial_model = acp_session.agent.config.get_active_model()
|
||||
initial_input_price = initial_model.input_price
|
||||
initial_output_price = initial_model.output_price
|
||||
|
||||
initial_stats_input = acp_session.agent.stats.input_price_per_million
|
||||
initial_stats_output = acp_session.agent.stats.output_price_per_million
|
||||
|
||||
assert acp_session.agent.stats.input_price_per_million == initial_input_price
|
||||
assert acp_session.agent.stats.output_price_per_million == initial_output_price
|
||||
|
||||
response = await acp_agent.setSessionModel(
|
||||
SetSessionModelRequest(sessionId=session_id, modelId="devstral-small")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
|
||||
new_model = acp_session.agent.config.get_active_model()
|
||||
new_input_price = new_model.input_price
|
||||
new_output_price = new_model.output_price
|
||||
|
||||
assert new_input_price != initial_input_price
|
||||
assert new_output_price != initial_output_price
|
||||
|
||||
assert acp_session.agent.stats.input_price_per_million == new_input_price
|
||||
assert acp_session.agent.stats.output_price_per_million == new_output_price
|
||||
|
||||
assert acp_session.agent.stats.input_price_per_million != initial_stats_input
|
||||
assert acp_session.agent.stats.output_price_per_million != initial_stats_output
|
||||
269
tests/acp/test_write_file.py
Normal file
269
tests/acp/test_write_file.py
Normal file
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import WriteTextFileRequest
|
||||
import pytest
|
||||
|
||||
from vibe.acp.tools.builtins.write_file import AcpWriteFileState, WriteFile
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.write_file import (
|
||||
WriteFileArgs,
|
||||
WriteFileConfig,
|
||||
WriteFileResult,
|
||||
)
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
|
||||
|
||||
class MockConnection:
|
||||
def __init__(
|
||||
self, write_error: Exception | None = None, file_exists: bool = False
|
||||
) -> None:
|
||||
self._write_error = write_error
|
||||
self._file_exists = file_exists
|
||||
self._write_text_file_called = False
|
||||
self._session_update_called = False
|
||||
self._last_write_request: WriteTextFileRequest | None = None
|
||||
|
||||
async def writeTextFile(self, request: WriteTextFileRequest) -> None:
|
||||
self._write_text_file_called = True
|
||||
self._last_write_request = request
|
||||
|
||||
if self._write_error:
|
||||
raise self._write_error
|
||||
|
||||
async def sessionUpdate(self, notification) -> None:
|
||||
self._session_update_called = True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_connection() -> MockConnection:
|
||||
return MockConnection()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def acp_write_file_tool(mock_connection: MockConnection, tmp_path: Path) -> WriteFile:
|
||||
config = WriteFileConfig(workdir=tmp_path)
|
||||
state = AcpWriteFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session_123",
|
||||
tool_call_id="test_tool_call_456",
|
||||
)
|
||||
return WriteFile(config=config, state=state)
|
||||
|
||||
|
||||
class TestAcpWriteFileBasic:
|
||||
def test_get_name(self) -> None:
|
||||
assert WriteFile.get_name() == "write_file"
|
||||
|
||||
|
||||
class TestAcpWriteFileExecution:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success_new_file(
|
||||
self,
|
||||
acp_write_file_tool: WriteFile,
|
||||
mock_connection: MockConnection,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
test_file = tmp_path / "test_file.txt"
|
||||
args = WriteFileArgs(path=str(test_file), content="Hello, world!")
|
||||
result = await acp_write_file_tool.run(args)
|
||||
|
||||
assert isinstance(result, WriteFileResult)
|
||||
assert result.path == str(test_file)
|
||||
assert result.content == "Hello, world!"
|
||||
assert result.bytes_written == len(b"Hello, world!")
|
||||
assert result.file_existed is False
|
||||
assert mock_connection._write_text_file_called
|
||||
assert mock_connection._session_update_called
|
||||
|
||||
# Verify WriteTextFileRequest was created correctly
|
||||
request = mock_connection._last_write_request
|
||||
assert request is not None
|
||||
assert request.sessionId == "test_session_123"
|
||||
assert request.path == str(test_file)
|
||||
assert request.content == "Hello, world!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_success_overwrite(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
tool = WriteFile(
|
||||
config=WriteFileConfig(workdir=tmp_path),
|
||||
state=AcpWriteFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
test_file = tmp_path / "existing_file.txt"
|
||||
test_file.touch()
|
||||
# Simulate existing file by checking in the core tool logic
|
||||
# The ACP tool doesn't check existence, it's handled by the core tool
|
||||
args = WriteFileArgs(path=str(test_file), content="New content", overwrite=True)
|
||||
result = await tool.run(args)
|
||||
|
||||
assert isinstance(result, WriteFileResult)
|
||||
assert result.path == str(test_file)
|
||||
assert result.content == "New content"
|
||||
assert result.bytes_written == len(b"New content")
|
||||
assert result.file_existed is True
|
||||
assert mock_connection._write_text_file_called
|
||||
assert mock_connection._session_update_called
|
||||
|
||||
# Verify WriteTextFileRequest was created correctly
|
||||
request = mock_connection._last_write_request
|
||||
assert request is not None
|
||||
assert request.sessionId == "test_session"
|
||||
assert request.path == str(test_file)
|
||||
assert request.content == "New content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_write_error(
|
||||
self, mock_connection: MockConnection, tmp_path: Path
|
||||
) -> None:
|
||||
mock_connection._write_error = RuntimeError("Permission denied")
|
||||
|
||||
tool = WriteFile(
|
||||
config=WriteFileConfig(workdir=tmp_path),
|
||||
state=AcpWriteFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id="test_session",
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
test_file = tmp_path / "test.txt"
|
||||
args = WriteFileArgs(path=str(test_file), content="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert str(exc_info.value) == f"Error writing {test_file}: Permission denied"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_connection(self, tmp_path: Path) -> None:
|
||||
tool = WriteFile(
|
||||
config=WriteFileConfig(workdir=tmp_path),
|
||||
state=AcpWriteFileState.model_construct(
|
||||
connection=None, session_id="test_session", tool_call_id="test_call"
|
||||
),
|
||||
)
|
||||
|
||||
args = WriteFileArgs(path=str(tmp_path / "test.txt"), content="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Connection not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_without_session_id(self, tmp_path: Path) -> None:
|
||||
mock_connection = MockConnection()
|
||||
tool = WriteFile(
|
||||
config=WriteFileConfig(workdir=tmp_path),
|
||||
state=AcpWriteFileState.model_construct(
|
||||
connection=mock_connection, # type: ignore[arg-type]
|
||||
session_id=None,
|
||||
tool_call_id="test_call",
|
||||
),
|
||||
)
|
||||
|
||||
args = WriteFileArgs(path=str(tmp_path / "test.txt"), content="test")
|
||||
with pytest.raises(ToolError) as exc_info:
|
||||
await tool.run(args)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "Session ID not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
|
||||
class TestAcpWriteFileSessionUpdates:
|
||||
def test_tool_call_session_update(self) -> None:
|
||||
event = ToolCallEvent(
|
||||
tool_name="write_file",
|
||||
tool_call_id="test_call_123",
|
||||
args=WriteFileArgs(path="/tmp/test.txt", content="Hello"),
|
||||
tool_class=WriteFile,
|
||||
)
|
||||
|
||||
update = WriteFile.tool_call_session_update(event)
|
||||
assert update is not None
|
||||
assert update.sessionUpdate == "tool_call"
|
||||
assert update.toolCallId == "test_call_123"
|
||||
assert update.kind == "edit"
|
||||
assert update.title is not None
|
||||
assert update.content is not None
|
||||
assert len(update.content) == 1
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "/tmp/test.txt"
|
||||
assert update.content[0].oldText is None
|
||||
assert update.content[0].newText == "Hello"
|
||||
assert update.locations is not None
|
||||
assert len(update.locations) == 1
|
||||
assert update.locations[0].path == "/tmp/test.txt"
|
||||
|
||||
def test_tool_call_session_update_invalid_args(self) -> None:
|
||||
from vibe.core.types import FunctionCall, ToolCall
|
||||
|
||||
class InvalidArgs:
|
||||
pass
|
||||
|
||||
event = ToolCallEvent.model_construct(
|
||||
tool_name="write_file",
|
||||
tool_call_id="test_call_123",
|
||||
args=InvalidArgs(), # type: ignore[arg-type]
|
||||
tool_class=WriteFile,
|
||||
llm_tool_call=ToolCall(
|
||||
function=FunctionCall(name="write_file", arguments="{}"),
|
||||
type="function",
|
||||
index=0,
|
||||
),
|
||||
)
|
||||
|
||||
update = WriteFile.tool_call_session_update(event)
|
||||
assert update is None
|
||||
|
||||
def test_tool_result_session_update(self) -> None:
|
||||
result = WriteFileResult(
|
||||
path="/tmp/test.txt", content="Hello", bytes_written=5, file_existed=False
|
||||
)
|
||||
|
||||
event = ToolResultEvent(
|
||||
tool_name="write_file",
|
||||
tool_call_id="test_call_123",
|
||||
result=result,
|
||||
tool_class=WriteFile,
|
||||
)
|
||||
|
||||
update = WriteFile.tool_result_session_update(event)
|
||||
assert update is not None
|
||||
assert update.sessionUpdate == "tool_call_update"
|
||||
assert update.toolCallId == "test_call_123"
|
||||
assert update.status == "completed"
|
||||
assert update.content is not None
|
||||
assert len(update.content) == 1
|
||||
assert update.content[0].type == "diff"
|
||||
assert update.content[0].path == "/tmp/test.txt"
|
||||
assert update.content[0].oldText is None
|
||||
assert update.content[0].newText == "Hello"
|
||||
assert update.locations is not None
|
||||
assert len(update.locations) == 1
|
||||
assert update.locations[0].path == "/tmp/test.txt"
|
||||
|
||||
def test_tool_result_session_update_invalid_result(self) -> None:
|
||||
class InvalidResult:
|
||||
pass
|
||||
|
||||
event = ToolResultEvent.model_construct(
|
||||
tool_name="write_file",
|
||||
tool_call_id="test_call_123",
|
||||
result=InvalidResult(), # type: ignore[arg-type]
|
||||
tool_class=WriteFile,
|
||||
)
|
||||
|
||||
update = WriteFile.tool_result_session_update(event)
|
||||
assert update is None
|
||||
231
tests/autocompletion/test_file_indexer.py
Normal file
231
tests/autocompletion/test_file_indexer.py
Normal file
@@ -0,0 +1,231 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.autocompletion.file_indexer import FileIndexer
|
||||
|
||||
# This suite runs against the real filesystem and watcher. A faked store/watcher
|
||||
# split would be faster to unit-test, but given time constraints and the low churn
|
||||
# expected for this feature, integration coverage was chosen as a trade-off.
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def file_indexer() -> Generator[FileIndexer]:
|
||||
indexer = FileIndexer()
|
||||
yield indexer
|
||||
indexer.shutdown()
|
||||
|
||||
|
||||
def _wait_for(condition: Callable[[], bool], timeout=3.0) -> bool:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
if condition():
|
||||
return True
|
||||
time.sleep(0.05)
|
||||
return False
|
||||
|
||||
|
||||
def test_updates_index_on_file_creation(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
target = tmp_path / "new_file.py"
|
||||
target.write_text("", encoding="utf-8")
|
||||
|
||||
assert _wait_for(
|
||||
lambda: any(
|
||||
entry.rel == target.name for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_updates_index_on_file_deletion(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
target = tmp_path / "new_file.py"
|
||||
target.write_text("", encoding="utf-8")
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
target.unlink()
|
||||
|
||||
assert _wait_for(
|
||||
lambda: all(
|
||||
entry.rel != target.name for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_updates_index_on_file_rename(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
old_file = tmp_path / "old_name.py"
|
||||
old_file.write_text("", encoding="utf-8")
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
new_file = tmp_path / "new_name.py"
|
||||
old_file.rename(new_file)
|
||||
|
||||
assert _wait_for(
|
||||
lambda: all(
|
||||
entry.rel != old_file.name for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
and any(
|
||||
entry.rel == new_file.name for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_updates_index_on_folder_rename(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
old_folder = tmp_path / "old_folder"
|
||||
old_folder.mkdir()
|
||||
number_of_files = 5
|
||||
file_names = [f"file{i}.py" for i in range(1, number_of_files + 1)]
|
||||
old_file_paths = [old_folder / name for name in file_names]
|
||||
for file_path in old_file_paths:
|
||||
file_path.write_text("", encoding="utf-8")
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
new_folder = tmp_path / "new_folder"
|
||||
old_folder.rename(new_folder)
|
||||
|
||||
assert _wait_for(
|
||||
lambda: (
|
||||
entries := file_indexer.get_index(Path(".")),
|
||||
all(not entry.rel.startswith("old_folder/") for entry in entries)
|
||||
and all(
|
||||
any(entry.rel == f"new_folder/{name}" for entry in entries)
|
||||
for name in file_names
|
||||
),
|
||||
)[1]
|
||||
)
|
||||
|
||||
|
||||
def test_updates_index_incrementally_by_default(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
rebuilds_before = file_indexer.stats.rebuilds
|
||||
incremental_before = file_indexer.stats.incremental_updates
|
||||
|
||||
target = tmp_path / "stats_file.py"
|
||||
target.write_text("", encoding="utf-8")
|
||||
|
||||
assert _wait_for(
|
||||
lambda: any(
|
||||
entry.rel == target.name for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
assert file_indexer.stats.rebuilds == rebuilds_before
|
||||
assert file_indexer.stats.incremental_updates >= incremental_before + 1
|
||||
|
||||
|
||||
def test_rebuilds_index_when_mass_change_threshold_is_exceeded(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
mass_change_threshold = 5
|
||||
# in an ideal world, we would use "threshold + 1", but in reality, we need to test with a
|
||||
# number of files important enough to MAKE SURE that a batch of >= threshold events will be
|
||||
# detected by the watcher
|
||||
number_of_files = mass_change_threshold * 3
|
||||
monkeypatch.chdir(tmp_path)
|
||||
indexer = FileIndexer(mass_change_threshold=mass_change_threshold)
|
||||
try:
|
||||
indexer.get_index(Path("."))
|
||||
rebuilds_before = indexer.stats.rebuilds
|
||||
|
||||
ThreadPoolExecutor(max_workers=number_of_files).map(
|
||||
lambda i: (tmp_path / f"bulk{i}.py").write_text("", encoding="utf-8"),
|
||||
range(number_of_files),
|
||||
)
|
||||
|
||||
assert _wait_for(lambda: len(indexer.get_index(Path("."))) == number_of_files)
|
||||
# we do not assert that "incremental_updates" did not change,
|
||||
# as the watcher potentially reported some batches of events that were
|
||||
# smaller than the threshold
|
||||
assert indexer.stats.rebuilds >= rebuilds_before + 1
|
||||
finally:
|
||||
indexer.shutdown()
|
||||
|
||||
|
||||
def test_switching_between_roots_restarts_index(
|
||||
tmp_path: Path,
|
||||
tmp_path_factory: pytest.TempPathFactory,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
file_indexer: FileIndexer,
|
||||
) -> None:
|
||||
first_root = tmp_path
|
||||
second_root = tmp_path_factory.mktemp("second-root")
|
||||
(first_root / "first.py").write_text("", encoding="utf-8")
|
||||
(second_root / "second.py").write_text("", encoding="utf-8")
|
||||
|
||||
monkeypatch.chdir(first_root)
|
||||
assert _wait_for(
|
||||
lambda: any(
|
||||
entry.rel == "first.py" for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
monkeypatch.chdir(second_root)
|
||||
assert _wait_for(
|
||||
lambda: all(
|
||||
entry.rel != "first.py" for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
assert _wait_for(
|
||||
lambda: any(
|
||||
entry.rel == "second.py" for entry in file_indexer.get_index(Path("."))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_watcher_failure_does_not_break_existing_index(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch, file_indexer: FileIndexer
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
seed = tmp_path / "seed.py"
|
||||
seed.write_text("", encoding="utf-8")
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
def boom(*_: object, **__: object) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(file_indexer._store, "apply_changes", boom)
|
||||
|
||||
(tmp_path / "new_file.py").write_text("", encoding="utf-8")
|
||||
|
||||
assert _wait_for(
|
||||
lambda: (
|
||||
entries := file_indexer.get_index(Path(".")),
|
||||
# new file was not added: watcher failed
|
||||
all(entry.rel != "new_file.py" for entry in entries)
|
||||
# but the existing index is still intact
|
||||
and all(entry.rel == "seed.py" for entry in entries),
|
||||
)[1]
|
||||
)
|
||||
|
||||
|
||||
def test_shutdown_cleans_up_resources(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.chdir(tmp_path)
|
||||
(tmp_path / "test.txt").write_text("", encoding="utf-8")
|
||||
file_indexer = FileIndexer()
|
||||
file_indexer.get_index(Path("."))
|
||||
|
||||
file_indexer.shutdown()
|
||||
assert file_indexer.get_index(Path(".")) == []
|
||||
96
tests/autocompletion/test_fuzzy.py
Normal file
96
tests/autocompletion/test_fuzzy.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from vibe.core.autocompletion.fuzzy import fuzzy_match
|
||||
|
||||
|
||||
def test_empty_pattern_matches_anything() -> None:
|
||||
result = fuzzy_match("", "any_text")
|
||||
|
||||
assert result.matched is True
|
||||
assert result.score == 0.0
|
||||
assert result.matched_indices == ()
|
||||
|
||||
|
||||
def test_matches_exact_prefix() -> None:
|
||||
result = fuzzy_match("src/", "src/main.py")
|
||||
|
||||
assert result.matched_indices == (0, 1, 2, 3)
|
||||
|
||||
|
||||
def test_no_match_when_characters_are_out_of_order() -> None:
|
||||
result = fuzzy_match("ms", "src/main.py")
|
||||
|
||||
assert result.matched is False
|
||||
|
||||
|
||||
def test_treats_consecutive_characters_as_subsequence() -> None:
|
||||
result = fuzzy_match("main", "src/main.py")
|
||||
|
||||
assert result.matched_indices == (4, 5, 6, 7)
|
||||
|
||||
|
||||
def test_ignores_case() -> None:
|
||||
result = fuzzy_match("SRC", "src/main.py")
|
||||
|
||||
assert result.matched_indices == (0, 1, 2)
|
||||
|
||||
|
||||
def test_treats_scattered_characters_as_subsequence() -> None:
|
||||
result = fuzzy_match("sm", "src/main.py")
|
||||
|
||||
assert result.matched_indices == (0, 4)
|
||||
|
||||
|
||||
def test_treats_path_separator_as_word_boundary() -> None:
|
||||
result = fuzzy_match("m", "src/main.py")
|
||||
|
||||
assert result.matched_indices == (4,)
|
||||
|
||||
|
||||
def test_prefers_word_boundary_matching_over_subsequence() -> None:
|
||||
boundary_result = fuzzy_match("ma", "src/main.py")
|
||||
subsequence_result = fuzzy_match("ma", "src/important.py")
|
||||
|
||||
assert boundary_result.score > subsequence_result.score
|
||||
|
||||
|
||||
def test_scores_exact_prefix_match_higher_than_consecutive_and_subsequence() -> None:
|
||||
prefix_result = fuzzy_match("src", "src/main.py")
|
||||
consecutive_result = fuzzy_match("main", "src/main.py")
|
||||
subsequence_result = fuzzy_match("sm", "src/main.py")
|
||||
|
||||
assert prefix_result.matched_indices == (0, 1, 2)
|
||||
assert prefix_result.score > consecutive_result.score
|
||||
assert prefix_result.score > subsequence_result.score
|
||||
|
||||
|
||||
def test_finds_no_match_when_pattern_is_longer_than_entry() -> None:
|
||||
result = fuzzy_match("very_long_pattern", "short")
|
||||
|
||||
assert result.matched is False
|
||||
|
||||
|
||||
def test_prefers_consecutive_match_over_subsequence() -> None:
|
||||
consecutive = fuzzy_match("main", "src/main.py")
|
||||
subsequence = fuzzy_match("mn", "src/main.py")
|
||||
|
||||
assert consecutive.score > subsequence.score
|
||||
|
||||
|
||||
def test_prefers_case_sensitive_match_over_case_insensitive() -> None:
|
||||
case_match = fuzzy_match("Main", "src/Main.py")
|
||||
case_insensitive_match = fuzzy_match("main", "src/Main.py")
|
||||
|
||||
assert case_match.score > case_insensitive_match.score
|
||||
|
||||
|
||||
def test_treats_uppercase_letter_as_word_boundary() -> None:
|
||||
result = fuzzy_match("MP", "src/MainPy.py")
|
||||
|
||||
assert result.matched_indices == (4, 8)
|
||||
|
||||
|
||||
def test_favors_earlier_positions() -> None:
|
||||
result = fuzzy_match("a", "banana")
|
||||
|
||||
assert result.matched_indices == (1,)
|
||||
122
tests/autocompletion/test_path_completer_fuzzy.py
Normal file
122
tests/autocompletion/test_path_completer_fuzzy.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.autocompletion.completers import PathCompleter
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
|
||||
(tmp_path / "src" / "utils").mkdir(parents=True)
|
||||
(tmp_path / "src" / "main.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "models.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core").mkdir(parents=True)
|
||||
(tmp_path / "src" / "core" / "logger.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "models.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "ports.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "sanitize.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "use_cases.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "validate.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "README.md").write_text("", encoding="utf-8")
|
||||
(tmp_path / ".env").write_text("", encoding="utf-8")
|
||||
(tmp_path / "config").mkdir(parents=True)
|
||||
(tmp_path / "config" / "settings.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "config" / "database.py").write_text("", encoding="utf-8")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_fuzzy_matches_subsequence_characters(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@sr", cursor_pos=3)
|
||||
|
||||
assert "@src/" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_consecutive_characters_higher(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/main", cursor_pos=9)
|
||||
|
||||
assert "@src/main.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_prefix_highest(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src", cursor_pos=4)
|
||||
|
||||
assert results[0].startswith("@src")
|
||||
|
||||
|
||||
def test_fuzzy_matches_across_directory_boundaries(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/main", cursor_pos=9)
|
||||
|
||||
assert "@src/main.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_case_insensitive(file_tree: Path) -> None:
|
||||
completer = PathCompleter()
|
||||
assert "@README.md" in completer.get_completions("@readme", cursor_pos=7)
|
||||
assert "@README.md" in completer.get_completions("@README", cursor_pos=7)
|
||||
|
||||
|
||||
def test_fuzzy_matches_word_boundaries_preferred(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/mp", cursor_pos=7)
|
||||
|
||||
assert "@src/models.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_empty_pattern_shows_all(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@", cursor_pos=1)
|
||||
|
||||
assert "@README.md" in results
|
||||
assert "@src/" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_hidden_files_only_with_dot(file_tree: Path) -> None:
|
||||
completer = PathCompleter()
|
||||
assert "@.env" not in completer.get_completions("@e", cursor_pos=2)
|
||||
assert "@.env" in completer.get_completions("@.", cursor_pos=2)
|
||||
|
||||
|
||||
def test_fuzzy_matches_directories_and_files(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/", cursor_pos=5)
|
||||
|
||||
assert any(r.endswith("/") for r in results)
|
||||
assert any(not r.endswith("/") for r in results)
|
||||
|
||||
|
||||
def test_fuzzy_matches_sorted_by_score(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/main", cursor_pos=9)
|
||||
|
||||
assert results[0] == "@src/main.py"
|
||||
|
||||
|
||||
def test_fuzzy_matches_nested_directories(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/core/l", cursor_pos=11)
|
||||
|
||||
assert "@src/core/logger.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_partial_filename(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/mo", cursor_pos=7)
|
||||
|
||||
assert "@src/models.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_multiple_files_with_same_pattern(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/m", cursor_pos=6)
|
||||
|
||||
assert "@src/main.py" in results
|
||||
assert "@src/models.py" in results
|
||||
|
||||
|
||||
def test_fuzzy_matches_no_results_when_no_match(file_tree: Path) -> None:
|
||||
completer = PathCompleter()
|
||||
assert completer.get_completions("@xyz123", cursor_pos=7) == []
|
||||
|
||||
|
||||
def test_fuzzy_matches_directory_traversal(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@src/", cursor_pos=5)
|
||||
|
||||
assert "@src/main.py" in results
|
||||
assert "@src/core/" in results
|
||||
assert "@src/utils/" in results
|
||||
69
tests/autocompletion/test_path_completer_recursive.py
Normal file
69
tests/autocompletion/test_path_completer_recursive.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.autocompletion.completers import PathCompleter
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
|
||||
(tmp_path / "vibe" / "acp").mkdir(parents=True)
|
||||
(tmp_path / "vibe" / "acp" / "entrypoint.py").write_text("")
|
||||
(tmp_path / "vibe" / "acp" / "agent.py").write_text("")
|
||||
(tmp_path / "vibe" / "cli" / "autocompletion").mkdir(parents=True)
|
||||
(tmp_path / "vibe" / "cli" / "autocompletion" / "fuzzy.py").write_text("")
|
||||
(tmp_path / "vibe" / "cli" / "autocompletion" / "completers.py").write_text("")
|
||||
(tmp_path / "tests" / "autocompletion").mkdir(parents=True)
|
||||
(tmp_path / "tests" / "autocompletion" / "test_fuzzy.py").write_text("")
|
||||
(tmp_path / "README.md").write_text("")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def test_finds_files_recursively_by_filename(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@entryp", cursor_pos=7)
|
||||
|
||||
assert results[0] == "@vibe/acp/entrypoint.py"
|
||||
|
||||
|
||||
def test_finds_files_recursively_by_partial_path(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@acp/entry", cursor_pos=10)
|
||||
|
||||
assert results[0] == "@vibe/acp/entrypoint.py"
|
||||
|
||||
|
||||
def test_finds_files_recursively_with_subsequence(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@acp/ent", cursor_pos=9)
|
||||
|
||||
assert results[0] == "@vibe/acp/entrypoint.py"
|
||||
|
||||
|
||||
def test_finds_multiple_matches_recursively(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@fuzzy", cursor_pos=6)
|
||||
|
||||
vibe_index = results.index("@vibe/cli/autocompletion/fuzzy.py")
|
||||
test_index = results.index("@tests/autocompletion/test_fuzzy.py")
|
||||
assert vibe_index < test_index
|
||||
|
||||
|
||||
def test_prioritizes_exact_path_matches(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@vibe/acp/entrypoint", cursor_pos=20)
|
||||
|
||||
assert results[0] == "@vibe/acp/entrypoint.py"
|
||||
|
||||
|
||||
def test_finds_files_when_pattern_matches_directory_name(file_tree: Path) -> None:
|
||||
results = PathCompleter().get_completions("@acp", cursor_pos=4)
|
||||
|
||||
assert results == [
|
||||
"@vibe/acp/",
|
||||
"@vibe/acp/agent.py",
|
||||
"@vibe/acp/entrypoint.py",
|
||||
"@vibe/cli/autocompletion/completers.py",
|
||||
"@tests/autocompletion/",
|
||||
"@tests/autocompletion/test_fuzzy.py",
|
||||
"@vibe/cli/autocompletion/",
|
||||
"@vibe/cli/autocompletion/fuzzy.py",
|
||||
]
|
||||
258
tests/autocompletion/test_path_completion_controller.py
Normal file
258
tests/autocompletion/test_path_completion_controller.py
Normal file
@@ -0,0 +1,258 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from textual import events
|
||||
|
||||
from vibe.cli.autocompletion.base import CompletionResult, CompletionView
|
||||
from vibe.cli.autocompletion.path_completion import PathCompletionController
|
||||
from vibe.core.autocompletion.completers import PathCompleter
|
||||
|
||||
|
||||
class StubView(CompletionView):
|
||||
def __init__(self) -> None:
|
||||
self.suggestions: list[tuple[list[tuple[str, str]], int]] = []
|
||||
self.clears = 0
|
||||
self.replacements: list[tuple[int, int, str]] = []
|
||||
|
||||
def render_completion_suggestions(
|
||||
self, suggestions: list[tuple[str, str]], selected_index: int
|
||||
) -> None:
|
||||
self.suggestions.append((suggestions, selected_index))
|
||||
|
||||
def clear_completion_suggestions(self) -> None:
|
||||
self.clears += 1
|
||||
|
||||
def replace_completion_range(self, start: int, end: int, replacement: str) -> None:
|
||||
self.replacements.append((start, end, replacement))
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
|
||||
(tmp_path / "src" / "utils").mkdir(parents=True)
|
||||
(tmp_path / "src" / "main.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core").mkdir(parents=True)
|
||||
(tmp_path / "src" / "core" / "logger.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "models.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "ports.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "sanitize.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "use_cases.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "core" / "validate.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "README.md").write_text("", encoding="utf-8")
|
||||
(tmp_path / ".env").write_text("", encoding="utf-8")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def make_controller(
|
||||
max_entries_to_process: int | None = None, target_matches: int | None = None
|
||||
) -> tuple[PathCompletionController, StubView]:
|
||||
completer_kwargs = {}
|
||||
if max_entries_to_process is not None:
|
||||
completer_kwargs["max_entries_to_process"] = max_entries_to_process
|
||||
if target_matches is not None:
|
||||
completer_kwargs["target_matches"] = target_matches
|
||||
|
||||
completer = PathCompleter(**completer_kwargs)
|
||||
view = StubView()
|
||||
controller = PathCompletionController(completer, view)
|
||||
return controller, view
|
||||
|
||||
|
||||
def test_lists_root_entries(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@", cursor_index=1)
|
||||
|
||||
suggestions, selected = view.suggestions[-1]
|
||||
assert selected == 0
|
||||
assert [alias for alias, _ in suggestions] == ["@README.md", "@src/"]
|
||||
|
||||
|
||||
def test_suggests_hidden_entries_only_with_dot_prefix(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@.", cursor_index=2)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert suggestions[0][0] == "@.env"
|
||||
|
||||
|
||||
def test_lists_nested_entries_when_prefixing_with_folder_name(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@src/", cursor_index=5)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert [alias for alias, _ in suggestions] == [
|
||||
"@src/core/",
|
||||
"@src/main.py",
|
||||
"@src/utils/",
|
||||
]
|
||||
|
||||
|
||||
def test_resets_when_fragment_invalid(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@src", cursor_index=4)
|
||||
assert view.suggestions
|
||||
|
||||
controller.on_text_changed("@src foo", cursor_index=8)
|
||||
assert view.clears == 1
|
||||
assert (
|
||||
controller.on_key(events.Key("tab", None), "@src foo", 8)
|
||||
is CompletionResult.IGNORED
|
||||
)
|
||||
|
||||
|
||||
def test_applies_selected_completion_on_tab_keycode(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@R", cursor_index=2)
|
||||
result = controller.on_key(events.Key("tab", None), "@R", 2)
|
||||
|
||||
assert result is CompletionResult.HANDLED
|
||||
assert view.replacements == [(0, 2, "@README.md")]
|
||||
assert view.clears == 1
|
||||
|
||||
|
||||
def test_applies_selected_completion_on_enter_keycode(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
controller.on_text_changed("@src/", cursor_index=5)
|
||||
controller.on_key(events.Key("down", None), "@src/", 5)
|
||||
|
||||
result = controller.on_key(events.Key("enter", None), "@src/", 5)
|
||||
|
||||
assert result is CompletionResult.HANDLED
|
||||
assert view.replacements == [(0, 5, "@src/main.py")]
|
||||
assert view.clears == 1
|
||||
|
||||
|
||||
def test_navigates_and_cycles_across_suggestions(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@src/", cursor_index=5)
|
||||
controller.on_key(events.Key("down", None), "@src/", 5)
|
||||
suggestions, selected_index = view.suggestions[-1]
|
||||
assert [alias for alias, _ in suggestions] == [
|
||||
"@src/core/",
|
||||
"@src/main.py",
|
||||
"@src/utils/",
|
||||
]
|
||||
assert selected_index == 1
|
||||
controller.on_key(events.Key("up", None), "@src/", 5)
|
||||
suggestions, selected_index = view.suggestions[-1]
|
||||
assert selected_index == 0
|
||||
|
||||
controller.on_key(events.Key("down", None), "@src/", 5)
|
||||
controller.on_key(events.Key("down", None), "@src/", 5)
|
||||
suggestions, selected_index = view.suggestions[-1]
|
||||
assert selected_index == 2
|
||||
|
||||
controller.on_key(events.Key("down", None), "@src/", 5)
|
||||
suggestions, selected_index = view.suggestions[-1]
|
||||
assert selected_index == 0
|
||||
|
||||
|
||||
def test_limits_suggestions_to_ten(file_tree: Path) -> None:
|
||||
(file_tree / "src" / "core" / "extra").mkdir(parents=True)
|
||||
[
|
||||
(file_tree / "src" / "core" / "extra" / f"extra_file_{i}.py").write_text(
|
||||
"", encoding="utf-8"
|
||||
)
|
||||
for i in range(1, 13)
|
||||
]
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@src/core/extra/", cursor_index=16)
|
||||
suggestions, selected_index = view.suggestions[-1]
|
||||
assert len(suggestions) == 10
|
||||
assert [alias for alias, _ in suggestions] == [
|
||||
"@src/core/extra/extra_file_1.py",
|
||||
"@src/core/extra/extra_file_10.py",
|
||||
"@src/core/extra/extra_file_11.py",
|
||||
"@src/core/extra/extra_file_12.py",
|
||||
"@src/core/extra/extra_file_2.py",
|
||||
"@src/core/extra/extra_file_3.py",
|
||||
"@src/core/extra/extra_file_4.py",
|
||||
"@src/core/extra/extra_file_5.py",
|
||||
"@src/core/extra/extra_file_6.py",
|
||||
"@src/core/extra/extra_file_7.py",
|
||||
]
|
||||
assert selected_index == 0
|
||||
|
||||
|
||||
def test_does_not_handle_when_cursor_at_beginning_of_input(file_tree: Path) -> None:
|
||||
controller, _ = make_controller()
|
||||
|
||||
assert not controller.can_handle("@file", cursor_index=0)
|
||||
assert not controller.can_handle("", cursor_index=0)
|
||||
assert not controller.can_handle("hello@file", cursor_index=0)
|
||||
|
||||
|
||||
def test_does_not_handle_when_cursor_before_or_at_the_at_symbol(
|
||||
file_tree: Path,
|
||||
) -> None:
|
||||
controller, _ = make_controller()
|
||||
|
||||
assert not controller.can_handle("@file", cursor_index=0)
|
||||
assert not controller.can_handle("hello@file", cursor_index=5)
|
||||
|
||||
|
||||
def test_does_handle_when_cursor_after_the_at_symbol_even_in_the_middle_of_the_input(
|
||||
file_tree: Path,
|
||||
) -> None:
|
||||
controller, _ = make_controller()
|
||||
|
||||
assert controller.can_handle("@file", cursor_index=1)
|
||||
assert controller.can_handle("hello @file", cursor_index=7)
|
||||
|
||||
|
||||
def test_lists_immediate_children_when_path_ends_with_slash(file_tree: Path) -> None:
|
||||
controller, view = make_controller()
|
||||
|
||||
controller.on_text_changed("@src/", cursor_index=5)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert [alias for alias, _ in suggestions] == [
|
||||
"@src/core/",
|
||||
"@src/main.py",
|
||||
"@src/utils/",
|
||||
]
|
||||
|
||||
|
||||
def test_respects_max_entries_to_process_limit(file_tree: Path) -> None:
|
||||
for i in range(30):
|
||||
(file_tree / f"file_{i:03d}.txt").write_text("", encoding="utf-8")
|
||||
|
||||
controller, view = make_controller(max_entries_to_process=10)
|
||||
|
||||
controller.on_text_changed("@", cursor_index=1)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert len(suggestions) <= 10
|
||||
|
||||
|
||||
def test_respects_target_matches_limit_for_listing(file_tree: Path) -> None:
|
||||
for i in range(30):
|
||||
(file_tree / f"item_{i:03d}.txt").write_text("", encoding="utf-8")
|
||||
|
||||
controller, view = make_controller(target_matches=5)
|
||||
|
||||
controller.on_text_changed("@", cursor_index=1)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert len(suggestions) <= 5
|
||||
|
||||
|
||||
def test_respects_target_matches_limit_for_fuzzy_search(file_tree: Path) -> None:
|
||||
for i in range(30):
|
||||
(file_tree / f"test_file_{i:03d}.py").write_text("", encoding="utf-8")
|
||||
|
||||
controller, view = make_controller(target_matches=5)
|
||||
|
||||
controller.on_text_changed("@test", cursor_index=5)
|
||||
|
||||
suggestions, _ = view.suggestions[-1]
|
||||
assert len(suggestions) <= 5
|
||||
142
tests/autocompletion/test_path_prompt_transformer.py
Normal file
142
tests/autocompletion/test_path_prompt_transformer.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from vibe.core.autocompletion.path_prompt_adapter import (
|
||||
DEFAULT_MAX_EMBED_BYTES,
|
||||
render_path_prompt,
|
||||
)
|
||||
|
||||
|
||||
def test_treats_paths_to_files_as_embedded_resources(tmp_path: Path) -> None:
|
||||
readme = tmp_path / "README.md"
|
||||
readme.write_text("hello", encoding="utf-8")
|
||||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir()
|
||||
main_py = src_dir / "main.py"
|
||||
main_py.write_text("print('hi')", encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"Please review @README.md and @src/main.py",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
expected = (
|
||||
f"Please review README.md and src/main.py\n\n"
|
||||
f"{readme.as_uri()}\n```\nhello\n```\n\n"
|
||||
f"{main_py.as_uri()}\n```\nprint('hi')\n```"
|
||||
)
|
||||
assert rendered == expected
|
||||
|
||||
|
||||
def test_treats_path_to_directory_as_resource_links(tmp_path: Path) -> None:
|
||||
docs_dir = tmp_path / "docs"
|
||||
docs_dir.mkdir()
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"See @docs/ for details",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
expected = f"See docs/ for details\n\nuri: {docs_dir.as_uri()}\nname: docs/"
|
||||
assert rendered == expected
|
||||
|
||||
|
||||
def test_keeps_emails_and_embeds_paths(tmp_path: Path) -> None:
|
||||
readme = tmp_path / "README.md"
|
||||
readme.write_text("hello", encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"Contact user@example.com about @README.md",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
expected = (
|
||||
f"Contact user@example.com about README.md\n\n"
|
||||
f"{readme.as_uri()}\n```\nhello\n```"
|
||||
)
|
||||
assert rendered == expected
|
||||
|
||||
|
||||
def test_ignores_nonexistent_paths(tmp_path: Path) -> None:
|
||||
rendered = render_path_prompt(
|
||||
"Missing @nope.txt here",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
assert rendered == "Missing @nope.txt here"
|
||||
|
||||
|
||||
def test_falls_back_to_link_for_binary_files(tmp_path: Path) -> None:
|
||||
binary_path = tmp_path / "image.bin"
|
||||
binary_path.write_bytes(b"\x00\x01\x02")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"Inspect @image.bin", base_dir=tmp_path, max_embed_bytes=DEFAULT_MAX_EMBED_BYTES
|
||||
)
|
||||
|
||||
assert (
|
||||
rendered == f"Inspect image.bin\n\nuri: {binary_path.as_uri()}\nname: image.bin"
|
||||
)
|
||||
|
||||
|
||||
def test_excludes_supposed_binary_files_quickly_before_reading_content(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
zip_like = tmp_path / "archive.zip"
|
||||
zip_like.write_text("text content inside but treated as binary", encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"Inspect @archive.zip",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
assert (
|
||||
rendered
|
||||
== f"Inspect archive.zip\n\nuri: {zip_like.as_uri()}\nname: archive.zip"
|
||||
)
|
||||
|
||||
|
||||
def test_applies_max_embed_size_guard(tmp_path: Path) -> None:
|
||||
large_file = tmp_path / "big.txt"
|
||||
large_file.write_text("a" * 50, encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"Review @big.txt", base_dir=tmp_path, max_embed_bytes=10
|
||||
)
|
||||
|
||||
assert rendered == f"Review big.txt\n\nuri: {large_file.as_uri()}\nname: big.txt"
|
||||
|
||||
|
||||
def test_parses_paths_with_special_characters_when_quoted(tmp_path: Path) -> None:
|
||||
weird = tmp_path / "weird name(1).txt"
|
||||
weird.write_text("odd", encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
'Open @"weird name(1).txt"',
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
assert rendered == f"Open weird name(1).txt\n\n{weird.as_uri()}\n```\nodd\n```"
|
||||
|
||||
|
||||
def test_deduplicates_identical_paths(tmp_path: Path) -> None:
|
||||
readme = tmp_path / "README.md"
|
||||
readme.write_text("hello", encoding="utf-8")
|
||||
|
||||
rendered = render_path_prompt(
|
||||
"See @README.md and again @README.md",
|
||||
base_dir=tmp_path,
|
||||
max_embed_bytes=DEFAULT_MAX_EMBED_BYTES,
|
||||
)
|
||||
|
||||
assert (
|
||||
rendered
|
||||
== f"See README.md and again README.md\n\n{readme.as_uri()}\n```\nhello\n```"
|
||||
)
|
||||
162
tests/autocompletion/test_slash_command_controller.py
Normal file
162
tests/autocompletion/test_slash_command_controller.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from textual import events
|
||||
|
||||
from vibe.cli.autocompletion.base import CompletionResult, CompletionView
|
||||
from vibe.cli.autocompletion.slash_command import SlashCommandController
|
||||
from vibe.core.autocompletion.completers import CommandCompleter
|
||||
|
||||
|
||||
class Suggestion(NamedTuple):
|
||||
alias: str
|
||||
description: str
|
||||
|
||||
|
||||
class SuggestionEvent(NamedTuple):
|
||||
suggestions: list[Suggestion]
|
||||
selected_index: int
|
||||
|
||||
|
||||
class Replacement(NamedTuple):
|
||||
start: int
|
||||
end: int
|
||||
replacement: str
|
||||
|
||||
|
||||
class StubView(CompletionView):
|
||||
def __init__(self) -> None:
|
||||
self.suggestion_events: list[SuggestionEvent] = []
|
||||
self.reset_count = 0
|
||||
self.replacements: list[Replacement] = []
|
||||
|
||||
def render_completion_suggestions(
|
||||
self, suggestions: list[tuple[str, str]], selected_index: int
|
||||
) -> None:
|
||||
typed = [Suggestion(alias, description) for alias, description in suggestions]
|
||||
self.suggestion_events.append(SuggestionEvent(typed, selected_index))
|
||||
|
||||
def clear_completion_suggestions(self) -> None:
|
||||
self.reset_count += 1
|
||||
|
||||
def replace_completion_range(self, start: int, end: int, replacement: str) -> None:
|
||||
self.replacements.append(Replacement(start, end, replacement))
|
||||
|
||||
|
||||
def key_event(key: str) -> events.Key:
|
||||
return events.Key(key, character=None)
|
||||
|
||||
|
||||
def make_controller(
|
||||
*, prefix: str | None = None
|
||||
) -> tuple[SlashCommandController, StubView]:
|
||||
commands = [
|
||||
("/config", "Show current configuration"),
|
||||
("/compact", "Compact history"),
|
||||
("/help", "Display help"),
|
||||
("/config", "Override description"),
|
||||
("/summarize", "Summarize history"),
|
||||
("/logpath", "Show log path"),
|
||||
("/exit", "Exit application"),
|
||||
("/vim", "Toggle vim keybindings"),
|
||||
]
|
||||
completer = CommandCompleter(commands)
|
||||
view = StubView()
|
||||
controller = SlashCommandController(completer, view)
|
||||
|
||||
if prefix is not None:
|
||||
controller.on_text_changed(prefix, cursor_index=len(prefix))
|
||||
view.suggestion_events.clear()
|
||||
|
||||
return controller, view
|
||||
|
||||
|
||||
def test_on_text_change_emits_matching_suggestions_in_insertion_order_and_ignores_duplicates() -> (
|
||||
None
|
||||
):
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
controller.on_text_changed("/c", cursor_index=2)
|
||||
|
||||
suggestions, selected = view.suggestion_events[-1]
|
||||
assert suggestions == [
|
||||
Suggestion("/config", "Override description"),
|
||||
Suggestion("/compact", "Compact history"),
|
||||
]
|
||||
assert selected == 0
|
||||
|
||||
|
||||
def test_on_text_change_filters_suggestions_case_insensitively() -> None:
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
controller.on_text_changed("/CO", cursor_index=3)
|
||||
|
||||
suggestions, _ = view.suggestion_events[-1]
|
||||
assert [suggestion.alias for suggestion in suggestions] == ["/config", "/compact"]
|
||||
|
||||
|
||||
def test_on_text_change_clears_suggestions_when_no_matches() -> None:
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
controller.on_text_changed("/c", cursor_index=2)
|
||||
controller.on_text_changed("config", cursor_index=6)
|
||||
|
||||
assert view.reset_count >= 1
|
||||
|
||||
|
||||
def test_on_text_change_limits_the_number_of_results_to_five_and_preserve_insertion_order() -> (
|
||||
None
|
||||
):
|
||||
controller, view = make_controller(prefix="/")
|
||||
|
||||
controller.on_text_changed("/", cursor_index=1)
|
||||
|
||||
suggestions, selected_index = view.suggestion_events[-1]
|
||||
assert len(suggestions) == 5
|
||||
assert [suggestion.alias for suggestion in suggestions] == [
|
||||
"/config",
|
||||
"/compact",
|
||||
"/help",
|
||||
"/summarize",
|
||||
"/logpath",
|
||||
]
|
||||
|
||||
|
||||
def test_on_key_tab_applies_selected_completion() -> None:
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
result = controller.on_key(key_event("tab"), text="/c", cursor_index=2)
|
||||
|
||||
assert result is CompletionResult.HANDLED
|
||||
assert view.replacements == [Replacement(0, 2, "/config")]
|
||||
assert view.reset_count == 1
|
||||
|
||||
|
||||
def test_on_key_down_and_up_cycle_selection() -> None:
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
controller.on_key(key_event("down"), text="/c", cursor_index=2)
|
||||
suggestions, selected_index = view.suggestion_events[-1]
|
||||
assert selected_index == 1
|
||||
|
||||
controller.on_key(key_event("down"), text="/c", cursor_index=2)
|
||||
suggestions, selected_index = view.suggestion_events[-1]
|
||||
assert selected_index == 0
|
||||
|
||||
controller.on_key(key_event("up"), text="/c", cursor_index=2)
|
||||
suggestions, selected_index = view.suggestion_events[-1]
|
||||
assert selected_index == 1
|
||||
assert [suggestion.alias for suggestion in suggestions] == ["/config", "/compact"]
|
||||
|
||||
|
||||
def test_on_key_enter_submits_selected_completion() -> None:
|
||||
controller, view = make_controller(prefix="/c")
|
||||
|
||||
controller.on_key(key_event("down"), text="/c", cursor_index=2)
|
||||
|
||||
result = controller.on_key(key_event("enter"), text="/c", cursor_index=2)
|
||||
|
||||
assert result is CompletionResult.SUBMIT
|
||||
assert view.replacements == [Replacement(0, 2, "/compact")]
|
||||
assert view.reset_count == 1
|
||||
306
tests/autocompletion/test_ui_chat_autocompletion.py
Normal file
306
tests/autocompletion/test_ui_chat_autocompletion.py
Normal file
@@ -0,0 +1,306 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from textual.content import Content
|
||||
from textual.style import Style
|
||||
from textual.widgets import Markdown
|
||||
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.textual_ui.widgets.chat_input.completion_popup import CompletionPopup
|
||||
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config() -> VibeConfig:
|
||||
return VibeConfig(session_logging=SessionLoggingConfig(enabled=False))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
|
||||
return VibeApp(config=vibe_config)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_appears_with_matching_suggestions(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/sum")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert popup.styles.display == "block"
|
||||
assert "/summarize" in popup_content
|
||||
assert "Compact conversation history by summarizing" in popup_content
|
||||
assert chat_input.value == "/sum"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_popup_hides_when_input_cleared(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/c")
|
||||
await pilot.press("backspace", "backspace")
|
||||
|
||||
assert popup.styles.display == "none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pressing_tab_writes_selected_command_and_keeps_popup_visible(
|
||||
vibe_app: VibeApp,
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/co")
|
||||
await pilot.press("tab")
|
||||
|
||||
assert chat_input.value == "/config"
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
def ensure_selected_command(popup: CompletionPopup, expected_alias: str) -> None:
|
||||
renderable = popup.render()
|
||||
assert isinstance(renderable, Content)
|
||||
content = str(renderable)
|
||||
|
||||
selected_aliases: list[str] = []
|
||||
for span in renderable.spans:
|
||||
style = span.style
|
||||
if isinstance(style, Style) and style.reverse:
|
||||
alias_text = content[span.start : span.end].strip()
|
||||
alias = alias_text.split()[0] if alias_text else ""
|
||||
selected_aliases.append(alias)
|
||||
|
||||
assert len(selected_aliases) == 1
|
||||
assert selected_aliases[0] == expected_alias
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arrow_navigation_updates_selected_suggestion(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/c")
|
||||
|
||||
ensure_selected_command(popup, "/cfg")
|
||||
await pilot.press("down")
|
||||
ensure_selected_command(popup, "/config")
|
||||
await pilot.press("up")
|
||||
ensure_selected_command(popup, "/cfg")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_arrow_navigation_cycles_through_suggestions(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/st")
|
||||
|
||||
ensure_selected_command(popup, "/stats")
|
||||
await pilot.press("down")
|
||||
ensure_selected_command(popup, "/status")
|
||||
await pilot.press("up")
|
||||
ensure_selected_command(popup, "/stats")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pressing_enter_submits_selected_command_and_hides_popup(
|
||||
vibe_app: VibeApp,
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"/hel") # typos:disable-line
|
||||
await pilot.press("enter")
|
||||
|
||||
assert chat_input.value == ""
|
||||
assert popup.styles.display == "none"
|
||||
message = vibe_app.query_one(".user-command-message")
|
||||
message_content = message.query_one(Markdown)
|
||||
assert "Show help message" in message_content.source
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def file_tree(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path:
|
||||
(tmp_path / "src" / "utils").mkdir(parents=True)
|
||||
(tmp_path / "src" / "utils" / "config.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "utils" / "database.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "utils" / "error_handling.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "utils" / "logger.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "utils" / "sanitize.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "utils" / "validate.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "src" / "main.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "vibe" / "acp").mkdir(parents=True)
|
||||
(tmp_path / "vibe" / "acp" / "entrypoint.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "vibe" / "acp" / "agent.py").write_text("", encoding="utf-8")
|
||||
(tmp_path / "README.md").write_text("", encoding="utf-8")
|
||||
(tmp_path / ".env").write_text("", encoding="utf-8")
|
||||
monkeypatch.chdir(tmp_path)
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_completion_popup_lists_files_and_directories(
|
||||
vibe_app: VibeApp, file_tree: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@s")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@src/" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_path_completion_popup_shows_up_to_ten_results(
|
||||
vibe_app: VibeApp, file_tree: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
(file_tree / "src" / "core" / "extra").mkdir(parents=True)
|
||||
[
|
||||
(file_tree / "src" / "core" / "extra" / f"extra_file_{i}.py").write_text(
|
||||
"", encoding="utf-8"
|
||||
)
|
||||
for i in range(1, 13)
|
||||
]
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@src/core/extra/")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@src/core/extra/extra_file_1.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_10.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_11.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_12.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_2.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_3.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_4.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_5.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_6.py" in popup_content
|
||||
assert "@src/core/extra/extra_file_7.py" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pressing_tab_writes_selected_path_name_and_hides_popup(
|
||||
vibe_app: VibeApp, file_tree: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"Print @REA")
|
||||
await pilot.press("tab")
|
||||
|
||||
assert chat_input.value == "Print @README.md "
|
||||
assert popup.styles.display == "none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pressing_enter_writes_selected_path_name_and_hides_popup(
|
||||
vibe_app: VibeApp, file_tree: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"Print @src/m")
|
||||
await pilot.press("enter")
|
||||
|
||||
assert chat_input.value == "Print @src/main.py "
|
||||
assert popup.styles.display == "none"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fuzzy_matches_subsequence_characters(
|
||||
file_tree: Path, vibe_app: VibeApp
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@src/utils/handling")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@src/utils/error_handling.py" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fuzzy_matches_word_boundaries(
|
||||
file_tree: Path, vibe_app: VibeApp
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@src/utils/eh")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@src/utils/error_handling.py" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_files_recursively_by_filename(
|
||||
file_tree: Path, vibe_app: VibeApp
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@entryp")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@vibe/acp/entrypoint.py" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_files_recursively_with_partial_path(
|
||||
file_tree: Path, vibe_app: VibeApp
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
|
||||
await pilot.press(*"@acp/entry")
|
||||
|
||||
popup_content = str(popup.render())
|
||||
assert "@vibe/acp/entrypoint.py" in popup_content
|
||||
assert popup.styles.display == "block"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_trigger_completion_when_navigating_history(
|
||||
file_tree: Path, vibe_app: VibeApp
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
popup = vibe_app.query_one(CompletionPopup)
|
||||
message_with_path = "Check @src/m"
|
||||
message_to_fill_history = "Yet another message to fill history"
|
||||
|
||||
await pilot.press(*message_with_path)
|
||||
await pilot.press("tab", "enter")
|
||||
await pilot.press(*message_to_fill_history)
|
||||
await pilot.press("enter")
|
||||
await pilot.press("up", "up")
|
||||
assert chat_input.value == "Check @src/main.py"
|
||||
await pilot.pause(0.2)
|
||||
# ensure popup is hidden - user was navigating history: we don't want to interrupt
|
||||
assert popup.styles.display == "none"
|
||||
await pilot.press("down")
|
||||
await pilot.pause(0.1)
|
||||
assert popup.styles.display == "none"
|
||||
# get back to the message with path completion; ensure again
|
||||
await pilot.press("up")
|
||||
await pilot.pause(0.1)
|
||||
assert chat_input.value == "Check @src/main.py"
|
||||
await pilot.pause(0.2)
|
||||
assert popup.styles.display == "none"
|
||||
0
tests/backend/__init__.py
Normal file
0
tests/backend/__init__.py
Normal file
6
tests/backend/data/__init__.py
Normal file
6
tests/backend/data/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
Url = str
|
||||
JsonResponse = dict
|
||||
ResultData = dict
|
||||
Chunk = bytes
|
||||
183
tests/backend/data/fireworks.py
Normal file
183
tests/backend/data/fireworks.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
|
||||
|
||||
SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
{
|
||||
"id": "fake_id_1234",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "accounts/fireworks/models/glm-4p5",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Some content",
|
||||
"reasoning_content": "Some reasoning content",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
"prompt_tokens_details": {"cached_tokens": 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": "stop",
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
TOOL_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
{
|
||||
"id": "fake_id_1234",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "accounts/fireworks/models/glm-4p5",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"reasoning_content": "Some reasoning content",
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "fake_id_5678",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "some_tool",
|
||||
"arguments": '{"some_argument": "some_argument_value"}',
|
||||
},
|
||||
"name": None,
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 200,
|
||||
"prompt_tokens_details": {"cached_tokens": 0},
|
||||
},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": "tool_calls",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "some_tool",
|
||||
"arguments": '{"some_argument": "some_argument_value"}',
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
STREAMED_SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
[
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}],"usage":null}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"reasoning_content":"Some reasoning content"},"finish_reason":null}],"usage":null}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"usage":null}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"accounts/fireworks/models/glm-4p5","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200,"prompt_tokens_details":{"cached_tokens":0}}}',
|
||||
rb"data: [DONE]",
|
||||
],
|
||||
[
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": "stop",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
STREAMED_TOOL_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
[
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": null}],"usage": null}',
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"reasoning_content": "Some reasoning content"},"finish_reason": null}],"usage": null}',
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"content": "Some content"},"finish_reason": null}],"usage": null}',
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"tool_calls": [{"index": 0,"id": "fake_id_151617","type": "function","function": {"name": "some_tool"}}]},"finish_reason": null}],"usage": null}',
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0,"delta": {"tool_calls": [{"index": 0,"id": null,"type": "function","function": {"arguments": "{\"some_argument\": \"some_arguments_value\"}"}}]},"finish_reason": null}],"usage": null}',
|
||||
rb'data: {"id": "fake_id_1234","object": "chat.completion.chunk","created": 1234567890,"model": "accounts/fireworks/models/glm-4p5","choices": [{"index": 0, "delta": {}, "finish_reason": "tool_calls"}],"usage": {"prompt_tokens": 100,"total_tokens": 300,"completion_tokens": 200,"prompt_tokens_details": {"cached_tokens": 190}}}',
|
||||
rb"data: [DONE]",
|
||||
],
|
||||
[
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"tool_calls": [{"name": "some_tool", "arguments": None, "index": 0}],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": None,
|
||||
"arguments": '{"some_argument": "some_arguments_value"}',
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": "tool_calls",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
173
tests/backend/data/mistral.py
Normal file
173
tests/backend/data/mistral.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
|
||||
|
||||
SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
{
|
||||
"id": "fake_id_1234",
|
||||
"created": 1234567890,
|
||||
"model": "devstral-latest",
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
},
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": None,
|
||||
"content": "Some content",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": "stop",
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
TOOL_CONVERSATION_PARAMS: list[tuple[Url, JsonResponse, ResultData]] = [
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
{
|
||||
"id": "fake_id_1234",
|
||||
"created": 1234567890,
|
||||
"model": "devstral-latest",
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
},
|
||||
"object": "chat.completion",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "tool_calls",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "fake_id_5678",
|
||||
"function": {
|
||||
"name": "some_tool",
|
||||
"arguments": '{"some_argument": "some_argument_value"}',
|
||||
},
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"content": "Some content",
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": "tool_calls",
|
||||
"tool_calls": [
|
||||
{
|
||||
"name": "some_tool",
|
||||
"arguments": '{"some_argument": "some_argument_value"}',
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
STREAMED_SIMPLE_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
[
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"p":"abcde"}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200},"p":"abcdefghijklmnopq"}',
|
||||
rb"data: [DONE]",
|
||||
],
|
||||
[
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": "stop",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
STREAMED_TOOL_CONVERSATION_PARAMS: list[tuple[Url, list[Chunk], list[ResultData]]] = [
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
[
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":"Some content"},"finish_reason":null}],"p":"a"}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"id":"fake_id_1234","function":{"name":"some_tool","arguments":""},"index":0}]},"finish_reason":null}],"p":"abcdef"}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"function":{"name":"","arguments":"{\"some_argument\": "},"index":0}]},"finish_reason":null}],"p":"abcdefghijklmnopq"}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"tool_calls":[{"id":"null","function":{"name":"","arguments":"\"some_argument_value\"}"},"index":0}]},"finish_reason":null}],"p":"abcdefghijklmnopqrstuvwxyz0123456"}',
|
||||
rb'data: {"id":"fake_id_1234","object":"chat.completion.chunk","created":1234567890,"model":"devstral-latest","choices":[{"index":0,"delta":{"content":""},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":100,"total_tokens":300,"completion_tokens":200},"p":"abcdefghijklmnopq"}',
|
||||
rb"data: [DONE]",
|
||||
],
|
||||
[
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "Some content",
|
||||
"finish_reason": None,
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"tool_calls": [{"name": "some_tool", "arguments": "", "index": 0}],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"tool_calls": [
|
||||
{"name": "", "arguments": '{"some_argument": ', "index": 0}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": None,
|
||||
"tool_calls": [
|
||||
{"name": "", "arguments": '"some_argument_value"}', "index": 0}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
|
||||
},
|
||||
{
|
||||
"message": "",
|
||||
"finish_reason": "tool_calls",
|
||||
"usage": {"prompt_tokens": 100, "completion_tokens": 200},
|
||||
},
|
||||
],
|
||||
)
|
||||
]
|
||||
248
tests/backend/test_backend.py
Normal file
248
tests/backend/test_backend.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Test data for this module was generated using real LLM provider API responses,
|
||||
with responses simplified and formatted to make them readable and maintainable.
|
||||
|
||||
To update or modify test parameters:
|
||||
1. Make actual API calls to the target providers
|
||||
2. Use the raw API responses as a base for updating test data
|
||||
3. Simplify only where necessary for readability while preserving core structure
|
||||
|
||||
The closer test data remains to real API responses, the more reliable and accurate
|
||||
the tests will be. Always prefer real API data over manually constructed examples.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from tests.backend.data import Chunk, JsonResponse, ResultData, Url
|
||||
from tests.backend.data.fireworks import (
|
||||
SIMPLE_CONVERSATION_PARAMS as FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
|
||||
STREAMED_SIMPLE_CONVERSATION_PARAMS as FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
||||
STREAMED_TOOL_CONVERSATION_PARAMS as FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
|
||||
TOOL_CONVERSATION_PARAMS as FIREWORKS_TOOL_CONVERSATION_PARAMS,
|
||||
)
|
||||
from tests.backend.data.mistral import (
|
||||
SIMPLE_CONVERSATION_PARAMS as MISTRAL_SIMPLE_CONVERSATION_PARAMS,
|
||||
STREAMED_SIMPLE_CONVERSATION_PARAMS as MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
||||
STREAMED_TOOL_CONVERSATION_PARAMS as MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
|
||||
TOOL_CONVERSATION_PARAMS as MISTRAL_TOOL_CONVERSATION_PARAMS,
|
||||
)
|
||||
from vibe.core.config import ModelConfig, ProviderConfig
|
||||
from vibe.core.llm.backend.generic import GenericBackend
|
||||
from vibe.core.llm.backend.mistral import MistralBackend
|
||||
from vibe.core.llm.exceptions import BackendError
|
||||
from vibe.core.llm.types import BackendLike
|
||||
from vibe.core.types import LLMChunk, LLMMessage, Role, ToolCall
|
||||
|
||||
|
||||
class TestBackend:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"base_url,json_response,result_data",
|
||||
[
|
||||
*FIREWORKS_SIMPLE_CONVERSATION_PARAMS,
|
||||
*FIREWORKS_TOOL_CONVERSATION_PARAMS,
|
||||
*MISTRAL_SIMPLE_CONVERSATION_PARAMS,
|
||||
*MISTRAL_TOOL_CONVERSATION_PARAMS,
|
||||
],
|
||||
)
|
||||
async def test_backend_complete(
|
||||
self, base_url: Url, json_response: JsonResponse, result_data: ResultData
|
||||
):
|
||||
with respx.mock(base_url=base_url) as mock_api:
|
||||
mock_api.post("/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(status_code=200, json=json_response)
|
||||
)
|
||||
provider = ProviderConfig(
|
||||
name="provider_name",
|
||||
api_base=f"{base_url}/v1",
|
||||
api_key_env_var="API_KEY",
|
||||
)
|
||||
|
||||
BackendClasses = [
|
||||
GenericBackend,
|
||||
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
|
||||
]
|
||||
for BackendClass in BackendClasses:
|
||||
backend: BackendLike = BackendClass(provider=provider)
|
||||
model = ModelConfig(
|
||||
name="model_name", provider="provider_name", alias="model_alias"
|
||||
)
|
||||
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
||||
|
||||
result = await backend.complete(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
tool_choice=None,
|
||||
extra_headers=None,
|
||||
)
|
||||
|
||||
assert result.message.content == result_data["message"]
|
||||
assert result.finish_reason == result_data["finish_reason"]
|
||||
assert result.usage is not None
|
||||
assert (
|
||||
result.usage.prompt_tokens == result_data["usage"]["prompt_tokens"]
|
||||
)
|
||||
assert (
|
||||
result.usage.completion_tokens
|
||||
== result_data["usage"]["completion_tokens"]
|
||||
)
|
||||
|
||||
if result.message.tool_calls is None:
|
||||
return
|
||||
|
||||
assert len(result.message.tool_calls) == len(result_data["tool_calls"])
|
||||
for i, tool_call in enumerate[ToolCall](result.message.tool_calls):
|
||||
assert (
|
||||
tool_call.function.name == result_data["tool_calls"][i]["name"]
|
||||
)
|
||||
assert (
|
||||
tool_call.function.arguments
|
||||
== result_data["tool_calls"][i]["arguments"]
|
||||
)
|
||||
assert tool_call.index == result_data["tool_calls"][i]["index"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"base_url,chunks,result_data",
|
||||
[
|
||||
*FIREWORKS_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
||||
*FIREWORKS_STREAMED_TOOL_CONVERSATION_PARAMS,
|
||||
*MISTRAL_STREAMED_SIMPLE_CONVERSATION_PARAMS,
|
||||
*MISTRAL_STREAMED_TOOL_CONVERSATION_PARAMS,
|
||||
],
|
||||
)
|
||||
async def test_backend_complete_streaming(
|
||||
self, base_url: Url, chunks: list[Chunk], result_data: list[ResultData]
|
||||
):
|
||||
with respx.mock(base_url=base_url) as mock_api:
|
||||
mock_api.post("/v1/chat/completions").mock(
|
||||
return_value=httpx.Response(
|
||||
status_code=200,
|
||||
stream=httpx.ByteStream(stream=b"\n\n".join(chunks)),
|
||||
headers={"Content-Type": "text/event-stream"},
|
||||
)
|
||||
)
|
||||
provider = ProviderConfig(
|
||||
name="provider_name",
|
||||
api_base=f"{base_url}/v1",
|
||||
api_key_env_var="API_KEY",
|
||||
)
|
||||
BackendClasses = [
|
||||
GenericBackend,
|
||||
*([MistralBackend] if base_url == "https://api.mistral.ai" else []),
|
||||
]
|
||||
for BackendClass in BackendClasses:
|
||||
backend: BackendLike = BackendClass(provider=provider)
|
||||
model = ModelConfig(
|
||||
name="model_name", provider="provider_name", alias="model_alias"
|
||||
)
|
||||
|
||||
messages = [
|
||||
LLMMessage(role=Role.user, content="List files in current dir")
|
||||
]
|
||||
|
||||
results: list[LLMChunk] = []
|
||||
async for result in backend.complete_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
tool_choice=None,
|
||||
extra_headers=None,
|
||||
):
|
||||
results.append(result)
|
||||
|
||||
for result, expected_result in zip(results, result_data, strict=True):
|
||||
assert result.message.content == expected_result["message"]
|
||||
assert result.finish_reason == expected_result["finish_reason"]
|
||||
assert result.usage is not None
|
||||
assert (
|
||||
result.usage.prompt_tokens
|
||||
== expected_result["usage"]["prompt_tokens"]
|
||||
)
|
||||
assert (
|
||||
result.usage.completion_tokens
|
||||
== expected_result["usage"]["completion_tokens"]
|
||||
)
|
||||
|
||||
if result.message.tool_calls is None:
|
||||
continue
|
||||
|
||||
for i, tool_call in enumerate(result.message.tool_calls):
|
||||
assert (
|
||||
tool_call.function.name
|
||||
== expected_result["tool_calls"][i]["name"]
|
||||
)
|
||||
assert (
|
||||
tool_call.function.arguments
|
||||
== expected_result["tool_calls"][i]["arguments"]
|
||||
)
|
||||
assert (
|
||||
tool_call.index == expected_result["tool_calls"][i]["index"]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"base_url,backend_class,response",
|
||||
[
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
GenericBackend,
|
||||
httpx.Response(status_code=500, text="Internal Server Error"),
|
||||
),
|
||||
(
|
||||
"https://api.fireworks.ai",
|
||||
GenericBackend,
|
||||
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
|
||||
),
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
MistralBackend,
|
||||
httpx.Response(status_code=500, text="Internal Server Error"),
|
||||
),
|
||||
(
|
||||
"https://api.mistral.ai",
|
||||
MistralBackend,
|
||||
httpx.Response(status_code=429, text="Rate Limit Exceeded"),
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_backend_complete_streaming_error(
|
||||
self,
|
||||
base_url: Url,
|
||||
backend_class: type[MistralBackend | GenericBackend],
|
||||
response: httpx.Response,
|
||||
):
|
||||
with respx.mock(base_url=base_url) as mock_api:
|
||||
mock_api.post("/v1/chat/completions").mock(return_value=response)
|
||||
provider = ProviderConfig(
|
||||
name="provider_name",
|
||||
api_base=f"{base_url}/v1",
|
||||
api_key_env_var="API_KEY",
|
||||
)
|
||||
backend = backend_class(provider=provider)
|
||||
model = ModelConfig(
|
||||
name="model_name", provider="provider_name", alias="model_alias"
|
||||
)
|
||||
messages = [LLMMessage(role=Role.user, content="Just say hi")]
|
||||
with pytest.raises(BackendError) as e:
|
||||
async for _ in backend.complete_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
tool_choice=None,
|
||||
extra_headers=None,
|
||||
):
|
||||
pass
|
||||
assert e.value.status == response.status_code
|
||||
assert e.value.reason == response.reason_phrase
|
||||
assert e.value.parsed_error is None
|
||||
87
tests/conftest.py
Normal file
87
tests/conftest.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource
|
||||
import pytest
|
||||
|
||||
_in_mem_config: dict[str, Any] = {}
|
||||
|
||||
|
||||
class InMemSettingsSource(PydanticBaseSettingsSource):
|
||||
def __init__(self, settings_cls: type[BaseSettings]) -> None:
|
||||
super().__init__(settings_cls)
|
||||
|
||||
def get_field_value(
|
||||
self, field: FieldInfo, field_name: str
|
||||
) -> tuple[Any, str, bool]:
|
||||
return _in_mem_config.get(field_name), field_name, False
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
return _in_mem_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="session")
|
||||
def _patch_vibe_config() -> None:
|
||||
"""Patch VibeConfig.settings_customise_sources to only use init_settings in tests.
|
||||
|
||||
This ensures that even production code that creates VibeConfig instances
|
||||
will only use init_settings and ignore environment variables and config files.
|
||||
Runs once per test session before any tests execute.
|
||||
"""
|
||||
from vibe.core.config import VibeConfig
|
||||
|
||||
def patched_settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (init_settings, InMemSettingsSource(settings_cls))
|
||||
|
||||
VibeConfig.settings_customise_sources = classmethod(
|
||||
patched_settings_customise_sources
|
||||
) # type: ignore[assignment]
|
||||
|
||||
def dump_config(cls, config: dict[str, Any]) -> None:
|
||||
global _in_mem_config
|
||||
_in_mem_config = config
|
||||
|
||||
VibeConfig.dump_config = classmethod(dump_config) # type: ignore[assignment]
|
||||
|
||||
def patched_load(cls, agent: str | None = None, **overrides: Any) -> Any:
|
||||
return cls(**overrides)
|
||||
|
||||
VibeConfig.load = classmethod(patched_load) # type: ignore[assignment]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_in_mem_config() -> None:
|
||||
"""Reset in-memory config before each test to prevent test isolation issues.
|
||||
|
||||
This ensures that each test starts with a clean configuration state,
|
||||
preventing race conditions and test interference when tests run in parallel
|
||||
or when VibeConfig.save_updates() modifies the shared _in_mem_config dict.
|
||||
"""
|
||||
global _in_mem_config
|
||||
_in_mem_config = {}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "mock")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_platform(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mock platform to be Linux with /bin/sh shell for consistent test behavior.
|
||||
|
||||
This ensures that platform-specific system prompt generation is consistent
|
||||
across all tests regardless of the actual platform running the tests.
|
||||
"""
|
||||
monkeypatch.setattr(sys, "platform", "linux")
|
||||
monkeypatch.setenv("SHELL", "/bin/sh")
|
||||
51
tests/core/test_config_migration.py
Normal file
51
tests/core/test_config_migration.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import tomllib
|
||||
|
||||
import tomli_w
|
||||
|
||||
from vibe.core import config
|
||||
from vibe.core.config import VibeConfig
|
||||
|
||||
|
||||
def _restore_dump_config(config_file: Path):
|
||||
original_dump_config = VibeConfig.dump_config
|
||||
|
||||
def real_dump_config(cls, config_dict: dict) -> None:
|
||||
try:
|
||||
with config_file.open("wb") as f:
|
||||
tomli_w.dump(config_dict, f)
|
||||
except OSError:
|
||||
config_file.write_text(
|
||||
"\n".join(
|
||||
f"{k} = {v!r}" for k, v in config_dict.items() if v is not None
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
VibeConfig.dump_config = classmethod(real_dump_config) # type: ignore[assignment]
|
||||
return original_dump_config
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _migrate_config_file(tmp_path: Path, content: str):
|
||||
config_file = tmp_path / "config.toml"
|
||||
config_file.write_text(content, encoding="utf-8")
|
||||
|
||||
original_config_file = config.CONFIG_FILE
|
||||
original_dump_config = _restore_dump_config(config_file)
|
||||
|
||||
try:
|
||||
config.CONFIG_FILE = config_file
|
||||
VibeConfig._migrate()
|
||||
yield config_file
|
||||
finally:
|
||||
config.CONFIG_FILE = original_config_file
|
||||
VibeConfig.dump_config = original_dump_config
|
||||
|
||||
|
||||
def _load_migrated_config(config_file: Path) -> dict:
|
||||
with config_file.open("rb") as f:
|
||||
return tomllib.load(f)
|
||||
0
tests/mock/__init__.py
Normal file
0
tests/mock/__init__.py
Normal file
16
tests/mock/mock_backend_factory.py
Normal file
16
tests/mock/mock_backend_factory.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from vibe.core.config import Backend
|
||||
from vibe.core.llm.backend.factory import BACKEND_FACTORY
|
||||
|
||||
|
||||
@contextmanager
|
||||
def mock_backend_factory(backend_type: Backend, factory_func):
|
||||
original = BACKEND_FACTORY[backend_type]
|
||||
try:
|
||||
BACKEND_FACTORY[backend_type] = factory_func
|
||||
yield
|
||||
finally:
|
||||
BACKEND_FACTORY[backend_type] = original
|
||||
66
tests/mock/mock_entrypoint.py
Normal file
66
tests/mock/mock_entrypoint.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Wrapper script that intercepts LLM calls when mocking is enabled.
|
||||
|
||||
This script is used to mock the LLM calls when testing the CLI.
|
||||
Mocked returns are stored in the VIBE_MOCK_LLM_DATA environment variable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from tests import TESTS_ROOT
|
||||
from tests.mock.utils import MOCK_DATA_ENV_VAR
|
||||
from vibe.core.types import LLMChunk
|
||||
|
||||
|
||||
def mock_llm_output() -> None:
|
||||
sys.path.insert(0, str(TESTS_ROOT))
|
||||
|
||||
# Apply mocking before importing any vibe modules
|
||||
mock_data_str = os.environ.get(MOCK_DATA_ENV_VAR)
|
||||
if not mock_data_str:
|
||||
raise ValueError(f"{MOCK_DATA_ENV_VAR} is not set")
|
||||
mock_data = json.loads(mock_data_str)
|
||||
try:
|
||||
chunks = [LLMChunk.model_validate(chunk) for chunk in mock_data]
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid mock data: {e}") from e
|
||||
|
||||
chunk_iterable = iter(chunks)
|
||||
|
||||
async def mock_complete(*args, **kwargs) -> LLMChunk:
|
||||
return next(chunk_iterable)
|
||||
|
||||
async def mock_complete_streaming(*args, **kwargs) -> AsyncGenerator[LLMChunk]:
|
||||
yield next(chunk_iterable)
|
||||
|
||||
patch(
|
||||
"vibe.core.llm.backend.mistral.MistralBackend.complete",
|
||||
side_effect=mock_complete,
|
||||
).start()
|
||||
patch(
|
||||
"vibe.core.llm.backend.generic.GenericBackend.complete",
|
||||
side_effect=mock_complete,
|
||||
).start()
|
||||
patch(
|
||||
"vibe.core.llm.backend.mistral.MistralBackend.complete_streaming",
|
||||
side_effect=mock_complete_streaming,
|
||||
).start()
|
||||
patch(
|
||||
"vibe.core.llm.backend.generic.GenericBackend.complete_streaming",
|
||||
side_effect=mock_complete_streaming,
|
||||
).start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mock_llm_output()
|
||||
|
||||
from vibe.acp.entrypoint import main
|
||||
|
||||
main()
|
||||
42
tests/mock/utils.py
Normal file
42
tests/mock/utils.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from vibe.core.types import LLMChunk, LLMMessage, LLMUsage, Role, ToolCall
|
||||
|
||||
MOCK_DATA_ENV_VAR = "VIBE_MOCK_LLM_DATA"
|
||||
|
||||
|
||||
def mock_llm_chunk(
|
||||
content: str = "Hello!",
|
||||
role: Role = Role.assistant,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
finish_reason: str | None = None,
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 5,
|
||||
) -> LLMChunk:
|
||||
message = LLMMessage(
|
||||
role=role,
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
name=name,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
return LLMChunk(
|
||||
message=message,
|
||||
usage=LLMUsage(
|
||||
prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
|
||||
def get_mocking_env(mock_chunks: list[LLMChunk] | None = None) -> dict[str, str]:
|
||||
if mock_chunks is None:
|
||||
mock_chunks = [mock_llm_chunk()]
|
||||
|
||||
mock_data = [LLMChunk.model_dump(mock_chunk) for mock_chunk in mock_chunks]
|
||||
|
||||
return {MOCK_DATA_ENV_VAR: json.dumps(mock_data)}
|
||||
69
tests/onboarding/test_run_onboarding.py
Normal file
69
tests/onboarding/test_run_onboarding.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import override
|
||||
|
||||
import pytest
|
||||
from textual.app import App
|
||||
|
||||
from vibe.setup import onboarding
|
||||
|
||||
|
||||
class StubApp(App[str | None]):
|
||||
def __init__(self, return_value: str | None) -> None:
|
||||
super().__init__()
|
||||
self._return_value = return_value
|
||||
|
||||
@override
|
||||
def run(self, *args: object, **kwargs: object) -> str | None:
|
||||
return self._return_value
|
||||
|
||||
|
||||
def _patch_env_file(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Path:
|
||||
env_file = tmp_path / ".env"
|
||||
monkeypatch.setattr(onboarding, "GLOBAL_ENV_FILE", env_file, raising=False)
|
||||
return env_file
|
||||
|
||||
|
||||
def _exit_raiser(code: int = 0) -> None:
|
||||
raise SystemExit(code)
|
||||
|
||||
|
||||
def test_exits_on_cancel(
|
||||
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
|
||||
) -> None:
|
||||
_patch_env_file(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr(sys, "exit", _exit_raiser)
|
||||
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
onboarding.run_onboarding(StubApp(None))
|
||||
|
||||
assert excinfo.value.code == 0
|
||||
out = capsys.readouterr().out
|
||||
assert "Setup cancelled. See you next time!" in out
|
||||
|
||||
|
||||
def test_warns_on_save_error(
|
||||
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
|
||||
) -> None:
|
||||
_patch_env_file(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr(sys, "exit", _exit_raiser)
|
||||
|
||||
onboarding.run_onboarding(StubApp("save_error:disk full"))
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Could not save API key" in out
|
||||
assert "disk full" in out
|
||||
|
||||
|
||||
def test_successfully_completes(
|
||||
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], tmp_path: Path
|
||||
) -> None:
|
||||
_patch_env_file(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr(sys, "exit", _exit_raiser)
|
||||
|
||||
onboarding.run_onboarding(StubApp("completed"))
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert out == ""
|
||||
124
tests/onboarding/test_ui_onboarding.py
Normal file
124
tests/onboarding/test_ui_onboarding.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from textual.events import Resize
|
||||
from textual.geometry import Size
|
||||
from textual.pilot import Pilot
|
||||
from textual.widgets import Input
|
||||
|
||||
from vibe.core import config as core_config
|
||||
from vibe.setup.onboarding import OnboardingApp
|
||||
import vibe.setup.onboarding.screens.api_key as api_key_module
|
||||
from vibe.setup.onboarding.screens.api_key import ApiKeyScreen
|
||||
from vibe.setup.onboarding.screens.theme_selection import THEMES, ThemeSelectionScreen
|
||||
|
||||
|
||||
async def _wait_for(
|
||||
condition: Callable[[], bool],
|
||||
pilot: Pilot,
|
||||
timeout: float = 5.0,
|
||||
interval: float = 0.05,
|
||||
) -> None:
|
||||
elapsed = 0.0
|
||||
while not condition():
|
||||
await pilot.pause(interval)
|
||||
if (elapsed := elapsed + interval) >= timeout:
|
||||
msg = "Timed out waiting for condition."
|
||||
raise AssertionError(msg)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def onboarding_app(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> tuple[OnboardingApp, Path, dict[str, Any]]:
|
||||
vibe_home = tmp_path / ".vibe"
|
||||
env_file = vibe_home / ".env"
|
||||
saved_updates: dict[str, Any] = {}
|
||||
|
||||
def record_updates(updates: dict[str, Any]) -> None:
|
||||
saved_updates.update(updates)
|
||||
|
||||
monkeypatch.setenv("VIBE_HOME", str(vibe_home))
|
||||
|
||||
for module in (core_config, api_key_module):
|
||||
monkeypatch.setattr(module, "GLOBAL_CONFIG_DIR", vibe_home, raising=False)
|
||||
monkeypatch.setattr(module, "GLOBAL_ENV_FILE", env_file, raising=False)
|
||||
|
||||
monkeypatch.setattr(
|
||||
core_config.VibeConfig,
|
||||
"save_updates",
|
||||
classmethod(lambda cls, updates: record_updates(updates)),
|
||||
)
|
||||
|
||||
return OnboardingApp(), env_file, saved_updates
|
||||
|
||||
|
||||
async def pass_welcome_screen(pilot: Pilot) -> None:
|
||||
welcome_screen = pilot.app.get_screen("welcome")
|
||||
await _wait_for(
|
||||
lambda: not welcome_screen.query_one("#enter-hint").has_class("hidden"), pilot
|
||||
)
|
||||
await pilot.press("enter")
|
||||
await _wait_for(lambda: isinstance(pilot.app.screen, ThemeSelectionScreen), pilot)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_gets_through_the_onboarding_successfully(
|
||||
onboarding_app: tuple[OnboardingApp, Path, dict[str, Any]],
|
||||
) -> None:
|
||||
app, env_file, config_updates = onboarding_app
|
||||
api_key_value = "sk-onboarding-test-key"
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pass_welcome_screen(pilot)
|
||||
|
||||
await pilot.press("enter")
|
||||
await _wait_for(lambda: isinstance(app.screen, ApiKeyScreen), pilot)
|
||||
api_screen = app.screen
|
||||
input_widget = api_screen.query_one("#key", Input)
|
||||
await pilot.press(*api_key_value)
|
||||
assert input_widget.value == api_key_value
|
||||
|
||||
await pilot.press("enter")
|
||||
await _wait_for(lambda: app.return_value is not None, pilot, timeout=2.0)
|
||||
|
||||
assert app.return_value == "completed"
|
||||
|
||||
assert env_file.is_file()
|
||||
env_contents = env_file.read_text(encoding="utf-8")
|
||||
assert "MISTRAL_API_KEY" in env_contents
|
||||
assert api_key_value in env_contents
|
||||
|
||||
assert config_updates.get("textual_theme") == app.theme
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_can_pick_a_theme_and_saves_selection(
|
||||
onboarding_app: tuple[OnboardingApp, Path, dict[str, Any]],
|
||||
) -> None:
|
||||
app, _, config_updates = onboarding_app
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pass_welcome_screen(pilot)
|
||||
|
||||
theme_screen = app.screen
|
||||
app.post_message(
|
||||
Resize(Size(40, 10), Size(40, 10))
|
||||
) # trigger the resize event handler
|
||||
preview = theme_screen.query_one("#preview")
|
||||
assert preview.styles.max_height is not None
|
||||
target_theme = "gruvbox"
|
||||
assert target_theme in THEMES
|
||||
start_index = THEMES.index(app.theme)
|
||||
target_index = THEMES.index(target_theme)
|
||||
steps_down = (target_index - start_index) % len(THEMES)
|
||||
await pilot.press(*["down"] * steps_down)
|
||||
assert app.theme == target_theme
|
||||
await pilot.press("enter")
|
||||
await _wait_for(lambda: isinstance(app.screen, ApiKeyScreen), pilot)
|
||||
|
||||
assert config_updates.get("textual_theme") == target_theme
|
||||
0
tests/playground/.gitkeep
Normal file
0
tests/playground/.gitkeep
Normal file
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 22 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 26 KiB |
File diff suppressed because one or more lines are too long
|
After Width: | Height: | Size: 24 KiB |
52
tests/snapshots/base_snapshot_test_app.py
Normal file
52
tests/snapshots/base_snapshot_test_app.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from rich.style import Style
|
||||
from textual.widgets.text_area import TextAreaTheme
|
||||
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.textual_ui.widgets.chat_input import ChatTextArea
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
def default_config() -> VibeConfig:
|
||||
"""Default configuration for snapshot testing.
|
||||
Remove as much interference as possible from the snapshot comparison, in order to get a clean pixel-to-pixel comparison.
|
||||
- Injects a fake backend to prevent (or stub) LLM calls.
|
||||
- Disables the welcome banner animation.
|
||||
- Forces a value for the displayed workdir
|
||||
- Hides the chat input cursor (as the blinking animation is not deterministic).
|
||||
"""
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
textual_theme="gruvbox",
|
||||
disable_welcome_banner_animation=True,
|
||||
displayed_workdir="/test/workdir",
|
||||
)
|
||||
|
||||
|
||||
class BaseSnapshotTestApp(VibeApp):
|
||||
CSS_PATH = "../../vibe/cli/textual_ui/app.tcss"
|
||||
|
||||
def __init__(self, config: VibeConfig | None = None, **kwargs):
|
||||
config = config or default_config()
|
||||
|
||||
super().__init__(config=config, **kwargs)
|
||||
|
||||
self.agent = Agent(
|
||||
config,
|
||||
auto_approve=self.auto_approve,
|
||||
enable_streaming=self.enable_streaming,
|
||||
backend=FakeBackend(),
|
||||
)
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
await super().on_mount()
|
||||
self._hide_chat_input_cursor()
|
||||
|
||||
def _hide_chat_input_cursor(self) -> None:
|
||||
text_area = self.query_one(ChatTextArea)
|
||||
hidden_cursor_theme = TextAreaTheme(name="hidden_cursor", cursor_style=Style())
|
||||
text_area.register_theme(hidden_cursor_theme)
|
||||
text_area.theme = "hidden_cursor"
|
||||
20
tests/snapshots/snap_compare.py
Normal file
20
tests/snapshots/snap_compare.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
from pathlib import PurePath
|
||||
from typing import Protocol
|
||||
|
||||
from textual.app import App
|
||||
from textual.pilot import Pilot
|
||||
|
||||
|
||||
class SnapCompare(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
app: str | PurePath | App,
|
||||
/,
|
||||
*,
|
||||
press: Iterable[str] = ...,
|
||||
terminal_size: tuple[int, int] = ...,
|
||||
run_before: (Callable[[Pilot], Awaitable[None] | None] | None) = ...,
|
||||
) -> bool: ...
|
||||
43
tests/snapshots/test_ui_snapshot_basic_conversation.py
Normal file
43
tests/snapshots/test_ui_snapshot_basic_conversation.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from textual.pilot import Pilot
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.snapshots.base_snapshot_test_app import BaseSnapshotTestApp, default_config
|
||||
from tests.snapshots.snap_compare import SnapCompare
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core.agent import Agent
|
||||
|
||||
|
||||
class SnapshotTestAppWithConversation(BaseSnapshotTestApp):
|
||||
def __init__(self) -> None:
|
||||
config = default_config()
|
||||
fake_backend = FakeBackend(
|
||||
results=[
|
||||
mock_llm_chunk(
|
||||
content="I'm the Vibe agent and I'm ready to help.",
|
||||
prompt_tokens=10_000,
|
||||
completion_tokens=2_500,
|
||||
)
|
||||
]
|
||||
)
|
||||
super().__init__(config=config)
|
||||
self.agent = Agent(
|
||||
config,
|
||||
auto_approve=self.auto_approve,
|
||||
enable_streaming=self.enable_streaming,
|
||||
backend=fake_backend,
|
||||
)
|
||||
|
||||
|
||||
def test_snapshot_shows_basic_conversation(snap_compare: SnapCompare) -> None:
|
||||
async def run_before(pilot: Pilot) -> None:
|
||||
await pilot.press(*"Hello there, who are you?")
|
||||
await pilot.press("enter")
|
||||
await pilot.pause(0.4)
|
||||
|
||||
assert snap_compare(
|
||||
"test_ui_snapshot_basic_conversation.py:SnapshotTestAppWithConversation",
|
||||
terminal_size=(120, 36),
|
||||
run_before=run_before,
|
||||
)
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from textual.pilot import Pilot
|
||||
from textual.widgets.markdown import MarkdownFence
|
||||
|
||||
from tests.snapshots.snap_compare import SnapCompare
|
||||
from vibe.cli.textual_ui.widgets.messages import AssistantMessage
|
||||
|
||||
|
||||
def test_snapshot_allows_horizontal_scrolling_for_long_code_blocks(
|
||||
snap_compare: SnapCompare,
|
||||
) -> None:
|
||||
assistant_message_md = """Here's a very long print instruction:
|
||||
|
||||
```python
|
||||
def lorem_ipsum():
|
||||
# Print a very long line (Lorem Ipsum)
|
||||
print("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt. Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem.")
|
||||
```
|
||||
|
||||
The `print` statement includes a very long line of Lorem Ipsum text to demonstrate a lengthy output."""
|
||||
|
||||
async def run_before(pilot: Pilot) -> None:
|
||||
app = pilot.app
|
||||
assistant_message = AssistantMessage(assistant_message_md)
|
||||
messages_area = app.query_one("#messages")
|
||||
await messages_area.mount(assistant_message)
|
||||
await assistant_message.write_initial_content()
|
||||
await pilot.pause(0.1)
|
||||
|
||||
markdown_fence = app.query_one(MarkdownFence)
|
||||
markdown_fence.scroll_relative(x=15, immediate=True)
|
||||
|
||||
assert snap_compare(
|
||||
"base_snapshot_test_app.py:BaseSnapshotTestApp",
|
||||
run_before=run_before,
|
||||
terminal_size=(120, 36),
|
||||
)
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from textual.pilot import Pilot
|
||||
|
||||
from tests.snapshots.base_snapshot_test_app import BaseSnapshotTestApp, default_config
|
||||
from tests.snapshots.snap_compare import SnapCompare
|
||||
from vibe.cli.update_notifier import FakeVersionUpdateGateway, VersionUpdate
|
||||
|
||||
|
||||
class SnapshotTestAppWithUpdate(BaseSnapshotTestApp):
|
||||
def __init__(self):
|
||||
config = default_config()
|
||||
config.enable_update_checks = True
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version="0.2.0")
|
||||
)
|
||||
super().__init__(config=config, version_update_notifier=version_update_notifier)
|
||||
|
||||
|
||||
def test_snapshot_shows_release_update_notification(snap_compare: SnapCompare) -> None:
|
||||
async def run_before(pilot: Pilot) -> None:
|
||||
await pilot.pause(0.2)
|
||||
|
||||
assert snap_compare(
|
||||
"test_ui_snapshot_release_update_notification.py:SnapshotTestAppWithUpdate",
|
||||
terminal_size=(120, 36),
|
||||
run_before=run_before,
|
||||
)
|
||||
115
tests/stubs/fake_backend.py
Normal file
115
tests/stubs/fake_backend.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from vibe.core.types import LLMChunk, LLMMessage
|
||||
|
||||
|
||||
class FakeBackend:
|
||||
"""Minimal async backend stub to drive Agent.act without network.
|
||||
|
||||
Provide a finite sequence of LLMResult objects to be returned by
|
||||
`complete`. When exhausted, returns an empty assistant message.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
results: Iterable[LLMChunk] | None = None,
|
||||
*,
|
||||
token_counter: Callable[[list[LLMMessage]], int] | None = None,
|
||||
exception_to_raise: Exception | None = None,
|
||||
) -> None:
|
||||
self._chunks = list(results or [])
|
||||
self._requests_messages: list[list[LLMMessage]] = []
|
||||
self._requests_extra_headers: list[dict[str, str] | None] = []
|
||||
self._count_tokens_calls: list[list[LLMMessage]] = []
|
||||
self._token_counter = token_counter or self._default_token_counter
|
||||
self._exception_to_raise = exception_to_raise
|
||||
|
||||
@property
|
||||
def requests_messages(self) -> list[list[LLMMessage]]:
|
||||
return self._requests_messages
|
||||
|
||||
@property
|
||||
def requests_extra_headers(self) -> list[dict[str, str] | None]:
|
||||
return self._requests_extra_headers
|
||||
|
||||
@staticmethod
|
||||
def _default_token_counter(messages: list[LLMMessage]) -> int:
|
||||
return 1
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature,
|
||||
tools,
|
||||
tool_choice,
|
||||
extra_headers,
|
||||
max_tokens,
|
||||
) -> LLMChunk:
|
||||
if self._exception_to_raise:
|
||||
raise self._exception_to_raise
|
||||
|
||||
self._requests_messages.append(messages)
|
||||
self._requests_extra_headers.append(extra_headers)
|
||||
if self._chunks:
|
||||
chunk = self._chunks.pop(0)
|
||||
if not self._chunks:
|
||||
chunk = chunk.model_copy(update={"finish_reason": "stop"})
|
||||
return chunk
|
||||
return mock_llm_chunk(content="", finish_reason="stop")
|
||||
|
||||
async def complete_streaming(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature,
|
||||
tools,
|
||||
tool_choice,
|
||||
extra_headers,
|
||||
max_tokens,
|
||||
) -> AsyncGenerator[LLMChunk]:
|
||||
if self._exception_to_raise:
|
||||
raise self._exception_to_raise
|
||||
|
||||
self._requests_messages.append(messages)
|
||||
self._requests_extra_headers.append(extra_headers)
|
||||
has_final_chunk = False
|
||||
while self._chunks:
|
||||
chunk = self._chunks.pop(0)
|
||||
is_last_provided_chunk = not self._chunks
|
||||
if is_last_provided_chunk:
|
||||
chunk = chunk.model_copy(update={"finish_reason": "stop"})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
has_final_chunk = True
|
||||
|
||||
yield chunk
|
||||
if has_final_chunk:
|
||||
break
|
||||
|
||||
if not has_final_chunk:
|
||||
yield mock_llm_chunk(content="", finish_reason="stop")
|
||||
|
||||
async def count_tokens(
|
||||
self,
|
||||
*,
|
||||
model,
|
||||
messages,
|
||||
temperature=0.0,
|
||||
tools,
|
||||
tool_choice=None,
|
||||
extra_headers,
|
||||
) -> int:
|
||||
self._count_tokens_calls.append(list(messages))
|
||||
return self._token_counter(messages)
|
||||
86
tests/stubs/fake_connection.py
Normal file
86
tests/stubs/fake_connection.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from acp import (
|
||||
Agent,
|
||||
AgentSideConnection,
|
||||
CreateTerminalRequest,
|
||||
KillTerminalCommandRequest,
|
||||
KillTerminalCommandResponse,
|
||||
ReadTextFileRequest,
|
||||
ReadTextFileResponse,
|
||||
ReleaseTerminalRequest,
|
||||
ReleaseTerminalResponse,
|
||||
RequestPermissionRequest,
|
||||
RequestPermissionResponse,
|
||||
SessionNotification,
|
||||
TerminalHandle,
|
||||
TerminalOutputRequest,
|
||||
TerminalOutputResponse,
|
||||
WaitForTerminalExitRequest,
|
||||
WaitForTerminalExitResponse,
|
||||
WriteTextFileRequest,
|
||||
WriteTextFileResponse,
|
||||
)
|
||||
|
||||
|
||||
class FakeAgentSideConnection(AgentSideConnection):
|
||||
def __init__(self, to_agent: Callable[[AgentSideConnection], Agent]) -> None:
|
||||
self._session_updates = []
|
||||
to_agent(self)
|
||||
|
||||
async def sessionUpdate(self, params: SessionNotification) -> None:
|
||||
self._session_updates.append(params)
|
||||
|
||||
async def requestPermission(
|
||||
self, params: RequestPermissionRequest
|
||||
) -> RequestPermissionResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def readTextFile(self, params: ReadTextFileRequest) -> ReadTextFileResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def writeTextFile(
|
||||
self, params: WriteTextFileRequest
|
||||
) -> WriteTextFileResponse | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def createTerminal(self, params: CreateTerminalRequest) -> TerminalHandle:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def terminalOutput(
|
||||
self, params: TerminalOutputRequest
|
||||
) -> TerminalOutputResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def releaseTerminal(
|
||||
self, params: ReleaseTerminalRequest
|
||||
) -> ReleaseTerminalResponse | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def waitForTerminalExit(
|
||||
self, params: WaitForTerminalExitRequest
|
||||
) -> WaitForTerminalExitResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def killTerminal(
|
||||
self, params: KillTerminalCommandRequest
|
||||
) -> KillTerminalCommandResponse | None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def extMethod(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def extNotification(self, method: str, params: dict[str, Any]) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def close(self) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
async def __aenter__(self) -> AgentSideConnection:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
await self.close()
|
||||
30
tests/stubs/fake_tool.py
Normal file
30
tests/stubs/fake_tool.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vibe.core.tools.base import BaseTool, BaseToolConfig, BaseToolState
|
||||
|
||||
|
||||
class FakeToolArgs(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class FakeToolResult(BaseModel):
|
||||
message: str = "fake tool executed"
|
||||
|
||||
|
||||
class FakeToolState(BaseToolState):
|
||||
pass
|
||||
|
||||
|
||||
class FakeTool(BaseTool[FakeToolArgs, FakeToolResult, BaseToolConfig, FakeToolState]):
|
||||
_exception_to_raise: BaseException | None = None
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "stub_tool"
|
||||
|
||||
async def run(self, args: FakeToolArgs) -> FakeToolResult:
|
||||
if self._exception_to_raise:
|
||||
raise self._exception_to_raise
|
||||
return FakeToolResult()
|
||||
56
tests/test_agent_auto_compact.py
Normal file
56
tests/test_agent_auto_compact.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
from vibe.core.types import (
|
||||
AssistantEvent,
|
||||
CompactEndEvent,
|
||||
CompactStartEvent,
|
||||
LLMMessage,
|
||||
Role,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_compact_triggers_and_batches_observer() -> None:
|
||||
observed: list[tuple[Role, str | None]] = []
|
||||
|
||||
def observer(msg: LLMMessage) -> None:
|
||||
observed.append((msg.role, msg.content))
|
||||
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="<summary>"),
|
||||
mock_llm_chunk(content="<final>"),
|
||||
])
|
||||
cfg = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False), auto_compact_threshold=1
|
||||
)
|
||||
agent = Agent(cfg, message_observer=observer, backend=backend)
|
||||
agent.stats.context_tokens = 2
|
||||
|
||||
events = [ev async for ev in agent.act("Hello")]
|
||||
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], CompactStartEvent)
|
||||
assert isinstance(events[1], CompactEndEvent)
|
||||
assert isinstance(events[2], AssistantEvent)
|
||||
start: CompactStartEvent = events[0]
|
||||
end: CompactEndEvent = events[1]
|
||||
final: AssistantEvent = events[2]
|
||||
assert start.current_context_tokens == 2
|
||||
assert start.threshold == 1
|
||||
assert end.old_context_tokens == 2
|
||||
assert end.new_context_tokens >= 1
|
||||
assert final.content == "<final>"
|
||||
|
||||
roles = [r for r, _ in observed]
|
||||
assert roles == [Role.system, Role.user, Role.assistant]
|
||||
assert (
|
||||
observed[1][1] is not None
|
||||
and "Last request from user was: Hello" in observed[1][1]
|
||||
)
|
||||
assert observed[2][1] == "<final>"
|
||||
77
tests/test_agent_backend.py
Normal file
77
tests/test_agent_backend.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config() -> VibeConfig:
|
||||
return VibeConfig(session_logging=SessionLoggingConfig(enabled=False))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_x_affinity_header_when_asking_an_answer(vibe_config: VibeConfig):
|
||||
backend = FakeBackend([mock_llm_chunk(content="Response", finish_reason="stop")])
|
||||
agent = Agent(vibe_config, backend=backend)
|
||||
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
|
||||
assert len(backend.requests_extra_headers) > 0
|
||||
headers = backend.requests_extra_headers[0]
|
||||
assert headers is not None
|
||||
assert "x-affinity" in headers
|
||||
assert headers["x-affinity"] == agent.session_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_x_affinity_header_when_asking_an_answer_streaming(
|
||||
vibe_config: VibeConfig,
|
||||
):
|
||||
backend = FakeBackend([mock_llm_chunk(content="Response", finish_reason="stop")])
|
||||
agent = Agent(vibe_config, backend=backend, enable_streaming=True)
|
||||
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
|
||||
assert len(backend.requests_extra_headers) > 0
|
||||
headers = backend.requests_extra_headers[0]
|
||||
assert headers is not None
|
||||
assert "x-affinity" in headers
|
||||
assert headers["x-affinity"] == agent.session_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_tokens_stats_based_on_backend_response(vibe_config: VibeConfig):
|
||||
chunk = mock_llm_chunk(
|
||||
content="Response",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=100,
|
||||
completion_tokens=50,
|
||||
)
|
||||
backend = FakeBackend([chunk])
|
||||
agent = Agent(vibe_config, backend=backend)
|
||||
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
|
||||
assert agent.stats.context_tokens == 150
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_tokens_stats_based_on_backend_response_streaming(
|
||||
vibe_config: VibeConfig,
|
||||
):
|
||||
final_chunk = mock_llm_chunk(
|
||||
content="Complete",
|
||||
finish_reason="stop",
|
||||
prompt_tokens=200,
|
||||
completion_tokens=75,
|
||||
)
|
||||
backend = FakeBackend([final_chunk])
|
||||
agent = Agent(vibe_config, backend=backend, enable_streaming=True)
|
||||
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
|
||||
assert agent.stats.context_tokens == 275
|
||||
430
tests/test_agent_observer_streaming.py
Normal file
430
tests/test_agent_observer_streaming.py
Normal file
@@ -0,0 +1,430 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
from vibe.core.llm.types import BackendLike
|
||||
from vibe.core.middleware import (
|
||||
ConversationContext,
|
||||
MiddlewareAction,
|
||||
MiddlewarePipeline,
|
||||
MiddlewareResult,
|
||||
ResetReason,
|
||||
)
|
||||
from vibe.core.tools.base import BaseToolConfig, ToolPermission
|
||||
from vibe.core.tools.builtins.todo import TodoArgs
|
||||
from vibe.core.types import (
|
||||
AssistantEvent,
|
||||
FunctionCall,
|
||||
LLMChunk,
|
||||
LLMMessage,
|
||||
Role,
|
||||
ToolCall,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from vibe.core.utils import (
|
||||
ApprovalResponse,
|
||||
CancellationReason,
|
||||
get_user_cancellation_message,
|
||||
)
|
||||
|
||||
|
||||
class InjectBeforeMiddleware:
|
||||
injectedMessage = "<injected>"
|
||||
|
||||
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
||||
"Inject a message just before the current step executes."
|
||||
return MiddlewareResult(
|
||||
action=MiddlewareAction.INJECT_MESSAGE, message=self.injectedMessage
|
||||
)
|
||||
|
||||
async def after_turn(self, context: ConversationContext) -> MiddlewareResult:
|
||||
return MiddlewareResult()
|
||||
|
||||
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def make_config(
|
||||
*,
|
||||
disable_logging: bool = True,
|
||||
enabled_tools: list[str] | None = None,
|
||||
tools: dict[str, BaseToolConfig] | None = None,
|
||||
) -> VibeConfig:
|
||||
cfg = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=not disable_logging),
|
||||
auto_compact_threshold=0,
|
||||
system_prompt_id="tests",
|
||||
include_project_context=False,
|
||||
include_prompt_detail=False,
|
||||
include_model_info=False,
|
||||
enabled_tools=enabled_tools or [],
|
||||
tools=tools or {},
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def observer_capture() -> tuple[
|
||||
list[tuple[Role, str | None]], Callable[[LLMMessage], None]
|
||||
]:
|
||||
observed: list[tuple[Role, str | None]] = []
|
||||
|
||||
def observer(msg: LLMMessage) -> None:
|
||||
observed.append((msg.role, msg.content))
|
||||
|
||||
return observed, observer
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_flushes_batched_messages_with_injection_middleware(
|
||||
observer_capture,
|
||||
) -> None:
|
||||
observed, observer = observer_capture
|
||||
|
||||
backend = FakeBackend([mock_llm_chunk(content="I can write very efficient code.")])
|
||||
agent = Agent(make_config(), message_observer=observer, backend=backend)
|
||||
agent.middleware_pipeline.add(InjectBeforeMiddleware())
|
||||
|
||||
async for _ in agent.act("How can you help?"):
|
||||
pass
|
||||
|
||||
assert len(observed) == 3
|
||||
assert [r for r, _ in observed] == [Role.system, Role.user, Role.assistant]
|
||||
assert observed[0][1] == "You are Vibe, a super useful programming assistant."
|
||||
# injected content should be appended to the user's message before emission
|
||||
assert (
|
||||
observed[1][1]
|
||||
== f"How can you help?\n\n{InjectBeforeMiddleware.injectedMessage}"
|
||||
)
|
||||
assert observed[2][1] == "I can write very efficient code."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_action_flushes_user_msg_before_returning(observer_capture) -> None:
|
||||
observed, observer = observer_capture
|
||||
|
||||
# max_turns=0 forces an immediate STOP on the first before_turn
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="My response will never reach you...")
|
||||
])
|
||||
agent = Agent(
|
||||
make_config(), message_observer=observer, max_turns=0, backend=backend
|
||||
)
|
||||
|
||||
async for _ in agent.act("Greet."):
|
||||
pass
|
||||
|
||||
assert len(observed) == 2
|
||||
# user's message should have been flushed before returning
|
||||
assert [r for r, _ in observed] == [Role.system, Role.user]
|
||||
assert observed[0][1] == "You are Vibe, a super useful programming assistant."
|
||||
assert observed[1][1] == "Greet."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_emits_user_and_assistant_msgs(observer_capture) -> None:
|
||||
observed, observer = observer_capture
|
||||
|
||||
backend = FakeBackend([mock_llm_chunk(content="Pong!")])
|
||||
agent = Agent(make_config(), message_observer=observer, backend=backend)
|
||||
|
||||
async for _ in agent.act("Ping?"):
|
||||
pass
|
||||
|
||||
assert len(observed) == 3
|
||||
assert [r for r, _ in observed] == [Role.system, Role.user, Role.assistant]
|
||||
assert observed[1][1] == "Ping?"
|
||||
assert observed[2][1] == "Pong!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_yields_assistant_event_with_usage_stats() -> None:
|
||||
backend = FakeBackend([mock_llm_chunk(content="Pong!")])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
events = [ev async for ev in agent.act("Ping?")]
|
||||
|
||||
assert len(events) == 1
|
||||
ev = events[-1]
|
||||
assert isinstance(ev, AssistantEvent)
|
||||
assert ev.content == "Pong!"
|
||||
# stats come from tests.mock.utils.mock_llm_result (prompt=10, completion=5)
|
||||
assert ev.prompt_tokens == 10
|
||||
assert ev.completion_tokens == 5
|
||||
assert ev.session_total_tokens == 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_streams_batched_chunks_in_order() -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Hello"),
|
||||
mock_llm_chunk(content=" from"),
|
||||
mock_llm_chunk(content=" Vibe"),
|
||||
mock_llm_chunk(content="! "),
|
||||
mock_llm_chunk(content="More"),
|
||||
mock_llm_chunk(content=" and"),
|
||||
mock_llm_chunk(content=" end"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend, enable_streaming=True)
|
||||
|
||||
events = [event async for event in agent.act("Stream, please.")]
|
||||
|
||||
assert len(events) == 2
|
||||
assert [event.content for event in events if isinstance(event, AssistantEvent)] == [
|
||||
"Hello from Vibe! More",
|
||||
" and end",
|
||||
]
|
||||
assert agent.messages[-1].role == Role.assistant
|
||||
assert agent.messages[-1].content == "Hello from Vibe! More and end"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_handles_streaming_with_tool_call_events_in_sequence() -> None:
|
||||
todo_tool_call = ToolCall(
|
||||
id="tc_stream",
|
||||
index=0,
|
||||
function=FunctionCall(name="todo", arguments='{"action": "read"}'),
|
||||
)
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Checking your todos."),
|
||||
mock_llm_chunk(content="", tool_calls=[todo_tool_call]),
|
||||
mock_llm_chunk(content="", finish_reason="stop"),
|
||||
mock_llm_chunk(content="Done reviewing todos."),
|
||||
])
|
||||
agent = Agent(
|
||||
make_config(
|
||||
enabled_tools=["todo"],
|
||||
tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
|
||||
),
|
||||
backend=backend,
|
||||
auto_approve=True,
|
||||
enable_streaming=True,
|
||||
)
|
||||
|
||||
events = [event async for event in agent.act("What about my todos?")]
|
||||
|
||||
assert [type(event) for event in events] == [
|
||||
AssistantEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
AssistantEvent,
|
||||
]
|
||||
assert isinstance(events[0], AssistantEvent)
|
||||
assert events[0].content == "Checking your todos."
|
||||
assert isinstance(events[1], ToolCallEvent)
|
||||
assert events[1].tool_name == "todo"
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is None
|
||||
assert events[2].skipped is False
|
||||
assert isinstance(events[3], AssistantEvent)
|
||||
assert events[3].content == "Done reviewing todos."
|
||||
assert agent.messages[-1].content == "Done reviewing todos."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_handles_tool_call_chunk_with_content() -> None:
|
||||
todo_tool_call = ToolCall(
|
||||
id="tc_content",
|
||||
index=0,
|
||||
function=FunctionCall(name="todo", arguments='{"action": "read"}'),
|
||||
)
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Preparing "),
|
||||
mock_llm_chunk(content="todo request", tool_calls=[todo_tool_call]),
|
||||
mock_llm_chunk(content=" complete", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(
|
||||
make_config(
|
||||
enabled_tools=["todo"],
|
||||
tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
|
||||
),
|
||||
backend=backend,
|
||||
auto_approve=True,
|
||||
enable_streaming=True,
|
||||
)
|
||||
|
||||
events = [event async for event in agent.act("Check todos with content.")]
|
||||
|
||||
assert [type(event) for event in events] == [
|
||||
AssistantEvent,
|
||||
AssistantEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
]
|
||||
assert isinstance(events[0], AssistantEvent)
|
||||
assert events[0].content == "Preparing todo request"
|
||||
assert isinstance(events[1], AssistantEvent)
|
||||
assert events[1].content == " complete"
|
||||
assert any(
|
||||
m.role == Role.assistant and m.content == "Preparing todo request complete"
|
||||
for m in agent.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_merges_streamed_tool_call_arguments() -> None:
|
||||
tool_call_part_one = ToolCall(
|
||||
id="tc_merge",
|
||||
index=0,
|
||||
function=FunctionCall(
|
||||
name="todo", arguments='{"action": "read", "note": "First '
|
||||
),
|
||||
)
|
||||
tool_call_part_two = ToolCall(
|
||||
id="tc_merge", index=0, function=FunctionCall(name="todo", arguments='part"}')
|
||||
)
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Planning: "),
|
||||
mock_llm_chunk(content="", tool_calls=[tool_call_part_one]),
|
||||
mock_llm_chunk(content="", tool_calls=[tool_call_part_two]),
|
||||
])
|
||||
agent = Agent(
|
||||
make_config(
|
||||
enabled_tools=["todo"],
|
||||
tools={"todo": BaseToolConfig(permission=ToolPermission.ALWAYS)},
|
||||
),
|
||||
backend=backend,
|
||||
auto_approve=True,
|
||||
enable_streaming=True,
|
||||
)
|
||||
|
||||
events = [event async for event in agent.act("Merge streamed tool call args.")]
|
||||
|
||||
assert [type(event) for event in events] == [
|
||||
AssistantEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
]
|
||||
call_event = events[1]
|
||||
assert isinstance(call_event, ToolCallEvent)
|
||||
assert call_event.tool_call_id == "tc_merge"
|
||||
call_args = cast(TodoArgs, call_event.args)
|
||||
assert call_args.action == "read"
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is None
|
||||
assert events[2].skipped is False
|
||||
assistant_with_calls = next(
|
||||
m for m in agent.messages if m.role == Role.assistant and m.tool_calls
|
||||
)
|
||||
reconstructed_calls = assistant_with_calls.tool_calls or []
|
||||
assert len(reconstructed_calls) == 1
|
||||
assert reconstructed_calls[0].function.arguments == (
|
||||
'{"action": "read", "note": "First part"}'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_raises_when_stream_never_signals_finish() -> None:
|
||||
class IncompleteStreamingBackend(BackendLike):
|
||||
def __init__(self, chunks: list[LLMChunk]) -> None:
|
||||
self._chunks = list(chunks)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def complete_streaming(self, **_: object):
|
||||
while self._chunks:
|
||||
yield self._chunks.pop(0)
|
||||
|
||||
async def complete(self, **_: object):
|
||||
return mock_llm_chunk(content="", finish_reason="stop")
|
||||
|
||||
async def count_tokens(self, **_: object) -> int:
|
||||
return 0
|
||||
|
||||
backend = IncompleteStreamingBackend([mock_llm_chunk(content="partial")])
|
||||
agent = Agent(make_config(), backend=backend, enable_streaming=True)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Streamed completion returned no chunks"):
|
||||
[event async for event in agent.act("Will this finish?")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_handles_user_cancellation_during_streaming() -> None:
|
||||
class CountingMiddleware(MiddlewarePipeline):
|
||||
def __init__(self) -> None:
|
||||
self.before_calls = 0
|
||||
self.after_calls = 0
|
||||
|
||||
async def before_turn(self, context: ConversationContext) -> MiddlewareResult:
|
||||
self.before_calls += 1
|
||||
return MiddlewareResult()
|
||||
|
||||
async def after_turn(self, context: ConversationContext) -> MiddlewareResult:
|
||||
self.after_calls += 1
|
||||
return MiddlewareResult()
|
||||
|
||||
def reset(self, reset_reason: ResetReason = ResetReason.STOP) -> None:
|
||||
return None
|
||||
|
||||
todo_tool_call = ToolCall(
|
||||
id="tc_cancel",
|
||||
index=0,
|
||||
function=FunctionCall(name="todo", arguments='{"action": "read"}'),
|
||||
)
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Preparing "),
|
||||
mock_llm_chunk(content="todo request", tool_calls=[todo_tool_call]),
|
||||
mock_llm_chunk(content="", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(
|
||||
make_config(
|
||||
enabled_tools=["todo"],
|
||||
tools={"todo": BaseToolConfig(permission=ToolPermission.ASK)},
|
||||
),
|
||||
backend=backend,
|
||||
auto_approve=False,
|
||||
enable_streaming=True,
|
||||
)
|
||||
middleware = CountingMiddleware()
|
||||
agent.middleware_pipeline.add(middleware)
|
||||
agent.set_approval_callback(
|
||||
lambda _name, _args, _id: (
|
||||
ApprovalResponse.NO,
|
||||
str(get_user_cancellation_message(CancellationReason.OPERATION_CANCELLED)),
|
||||
)
|
||||
)
|
||||
agent.interaction_logger.save_interaction = AsyncMock(return_value=None)
|
||||
|
||||
events = [event async for event in agent.act("Cancel mid stream?")]
|
||||
|
||||
assert [type(event) for event in events] == [
|
||||
AssistantEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
]
|
||||
assert middleware.before_calls == 1
|
||||
assert middleware.after_calls == 0
|
||||
assert isinstance(events[-1], ToolResultEvent)
|
||||
assert events[-1].skipped is True
|
||||
assert events[-1].skip_reason is not None
|
||||
assert "<user_cancellation>" in events[-1].skip_reason
|
||||
assert agent.interaction_logger.save_interaction.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_act_flushes_and_logs_when_streaming_errors(observer_capture) -> None:
|
||||
observed, observer = observer_capture
|
||||
backend = FakeBackend(exception_to_raise=RuntimeError("boom in streaming"))
|
||||
agent = Agent(
|
||||
make_config(), backend=backend, message_observer=observer, enable_streaming=True
|
||||
)
|
||||
agent.interaction_logger.save_interaction = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom in streaming"):
|
||||
[_ async for _ in agent.act("Trigger stream failure")]
|
||||
|
||||
assert [role for role, _ in observed] == [Role.system, Role.user]
|
||||
assert agent.interaction_logger.save_interaction.await_count == 1
|
||||
711
tests/test_agent_stats.py
Normal file
711
tests/test_agent_stats.py
Normal file
@@ -0,0 +1,711 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import (
|
||||
Backend,
|
||||
ModelConfig,
|
||||
ProviderConfig,
|
||||
SessionLoggingConfig,
|
||||
VibeConfig,
|
||||
)
|
||||
from vibe.core.tools.base import BaseToolConfig, ToolPermission
|
||||
from vibe.core.types import (
|
||||
AgentStats,
|
||||
AssistantEvent,
|
||||
CompactEndEvent,
|
||||
CompactStartEvent,
|
||||
FunctionCall,
|
||||
LLMMessage,
|
||||
Role,
|
||||
ToolCall,
|
||||
)
|
||||
|
||||
|
||||
def make_config(
|
||||
*,
|
||||
system_prompt_id: str = "tests",
|
||||
active_model: str = "devstral-latest",
|
||||
input_price: float = 0.4,
|
||||
output_price: float = 2.0,
|
||||
disable_logging: bool = True,
|
||||
auto_compact_threshold: int = 0,
|
||||
include_project_context: bool = False,
|
||||
include_prompt_detail: bool = False,
|
||||
enabled_tools: list[str] | None = None,
|
||||
todo_permission: ToolPermission = ToolPermission.ALWAYS,
|
||||
) -> VibeConfig:
|
||||
models = [
|
||||
ModelConfig(
|
||||
name="mistral-vibe-cli-latest",
|
||||
provider="mistral",
|
||||
alias="devstral-latest",
|
||||
input_price=input_price,
|
||||
output_price=output_price,
|
||||
),
|
||||
ModelConfig(
|
||||
name="devstral-small-latest",
|
||||
provider="mistral",
|
||||
alias="devstral-small",
|
||||
input_price=0.1,
|
||||
output_price=0.3,
|
||||
),
|
||||
ModelConfig(
|
||||
name="strawberry",
|
||||
provider="lechat",
|
||||
alias="strawberry",
|
||||
input_price=2.5,
|
||||
output_price=10.0,
|
||||
),
|
||||
]
|
||||
providers = [
|
||||
ProviderConfig(
|
||||
name="mistral",
|
||||
api_base="https://api.mistral.ai/v1",
|
||||
api_key_env_var="MISTRAL_API_KEY",
|
||||
backend=Backend.MISTRAL,
|
||||
),
|
||||
ProviderConfig(
|
||||
name="lechat",
|
||||
api_base="https://api.mistral.ai/v1",
|
||||
api_key_env_var="LECHAT_API_KEY",
|
||||
backend=Backend.MISTRAL,
|
||||
),
|
||||
]
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=not disable_logging),
|
||||
auto_compact_threshold=auto_compact_threshold,
|
||||
system_prompt_id=system_prompt_id,
|
||||
include_project_context=include_project_context,
|
||||
include_prompt_detail=include_prompt_detail,
|
||||
active_model=active_model,
|
||||
models=models,
|
||||
providers=providers,
|
||||
enabled_tools=enabled_tools or [],
|
||||
tools={"todo": BaseToolConfig(permission=todo_permission)},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def observer_capture() -> tuple[list[LLMMessage], Callable[[LLMMessage], None]]:
|
||||
observed: list[LLMMessage] = []
|
||||
|
||||
def observer(msg: LLMMessage) -> None:
|
||||
observed.append(msg)
|
||||
|
||||
return observed, observer
|
||||
|
||||
|
||||
class TestAgentStatsHelpers:
|
||||
def test_update_pricing(self) -> None:
|
||||
stats = AgentStats()
|
||||
stats.update_pricing(1.5, 3.0)
|
||||
assert stats.input_price_per_million == 1.5
|
||||
assert stats.output_price_per_million == 3.0
|
||||
|
||||
def test_reset_context_state_preserves_cumulative(self) -> None:
|
||||
stats = AgentStats(
|
||||
steps=5,
|
||||
session_prompt_tokens=1000,
|
||||
session_completion_tokens=500,
|
||||
tool_calls_succeeded=3,
|
||||
tool_calls_failed=1,
|
||||
context_tokens=800,
|
||||
last_turn_prompt_tokens=100,
|
||||
last_turn_completion_tokens=50,
|
||||
last_turn_duration=1.5,
|
||||
tokens_per_second=33.3,
|
||||
input_price_per_million=0.4,
|
||||
output_price_per_million=2.0,
|
||||
)
|
||||
|
||||
stats.reset_context_state()
|
||||
|
||||
assert stats.steps == 5
|
||||
assert stats.session_prompt_tokens == 1000
|
||||
assert stats.session_completion_tokens == 500
|
||||
assert stats.tool_calls_succeeded == 3
|
||||
assert stats.tool_calls_failed == 1
|
||||
assert stats.input_price_per_million == 0.4
|
||||
assert stats.output_price_per_million == 2.0
|
||||
|
||||
assert stats.context_tokens == 0
|
||||
assert stats.last_turn_prompt_tokens == 0
|
||||
assert stats.last_turn_completion_tokens == 0
|
||||
assert stats.last_turn_duration == 0.0
|
||||
assert stats.tokens_per_second == 0.0
|
||||
|
||||
def test_session_cost_computed_from_current_pricing(self) -> None:
|
||||
stats = AgentStats(
|
||||
session_prompt_tokens=1_000_000,
|
||||
session_completion_tokens=500_000,
|
||||
input_price_per_million=1.0,
|
||||
output_price_per_million=2.0,
|
||||
)
|
||||
# Cost = 1M * $1/M + 0.5M * $2/M = $1 + $1 = $2
|
||||
assert stats.session_cost == 2.0
|
||||
|
||||
stats.update_pricing(2.0, 4.0)
|
||||
# Cost = 1M * $2/M + 0.5M * $4/M = $2 + $2 = $4
|
||||
assert stats.session_cost == 4.0
|
||||
|
||||
|
||||
class TestReloadPreservesStats:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_preserves_session_tokens(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="First response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
old_session_prompt = agent.stats.session_prompt_tokens
|
||||
old_session_completion = agent.stats.session_completion_tokens
|
||||
assert old_session_prompt > 0
|
||||
assert old_session_completion > 0
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert agent.stats.session_prompt_tokens == old_session_prompt
|
||||
assert agent.stats.session_completion_tokens == old_session_completion
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_preserves_tool_call_stats(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Calling tool",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc1",
|
||||
function=FunctionCall(
|
||||
name="todo", arguments='{"action": "read"}'
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
mock_llm_chunk(content="Done", finish_reason="stop"),
|
||||
])
|
||||
config = make_config(enabled_tools=["todo"])
|
||||
agent = Agent(config, auto_approve=True, backend=backend)
|
||||
|
||||
async for _ in agent.act("Check todos"):
|
||||
pass
|
||||
|
||||
assert agent.stats.tool_calls_succeeded == 1
|
||||
assert agent.stats.tool_calls_agreed == 1
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert agent.stats.tool_calls_succeeded == 1
|
||||
assert agent.stats.tool_calls_agreed == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_preserves_steps(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="R1", finish_reason="stop"),
|
||||
mock_llm_chunk(content="R2", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("First"):
|
||||
pass
|
||||
async for _ in agent.act("Second"):
|
||||
pass
|
||||
|
||||
old_steps = agent.stats.steps
|
||||
assert old_steps >= 2
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert agent.stats.steps == old_steps
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_preserves_context_tokens_when_messages_preserved(
|
||||
self,
|
||||
) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
assert agent.stats.context_tokens > 0
|
||||
initial_context_tokens = agent.stats.context_tokens
|
||||
assert len(agent.messages) > 1
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert len(agent.messages) > 1
|
||||
assert agent.stats.context_tokens == initial_context_tokens
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_resets_context_tokens_when_no_messages(self) -> None:
|
||||
backend = FakeBackend([])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
assert len(agent.messages) == 1
|
||||
assert agent.stats.context_tokens == 0
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert len(agent.messages) == 1
|
||||
assert agent.stats.context_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_resets_context_tokens_when_system_prompt_changes(
|
||||
self,
|
||||
) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
config1 = make_config(system_prompt_id="tests")
|
||||
config2 = make_config(system_prompt_id="cli")
|
||||
agent = Agent(config1, backend=backend)
|
||||
[_ async for _ in agent.act("Hello")]
|
||||
assert agent.stats.context_tokens > 0
|
||||
assert len(agent.messages) > 1
|
||||
|
||||
await agent.reload_with_initial_messages(config=config2)
|
||||
|
||||
assert len(agent.messages) > 1
|
||||
assert agent.stats.context_tokens == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_updates_pricing_from_new_model(self, monkeypatch) -> None:
|
||||
monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
|
||||
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
config_mistral = make_config(active_model="devstral-latest")
|
||||
agent = Agent(config_mistral, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
assert agent.stats.input_price_per_million == 0.4
|
||||
assert agent.stats.output_price_per_million == 2.0
|
||||
|
||||
config_other = make_config(active_model="strawberry")
|
||||
await agent.reload_with_initial_messages(config=config_other)
|
||||
|
||||
assert agent.stats.input_price_per_million == 2.5
|
||||
assert agent.stats.output_price_per_million == 10.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_accumulates_tokens_across_configs(self, monkeypatch) -> None:
|
||||
monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
|
||||
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="First", finish_reason="stop"),
|
||||
mock_llm_chunk(content="After reload", finish_reason="stop"),
|
||||
])
|
||||
config1 = make_config(active_model="devstral-latest")
|
||||
agent = Agent(config1, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
tokens_after_first = (
|
||||
agent.stats.session_prompt_tokens + agent.stats.session_completion_tokens
|
||||
)
|
||||
|
||||
config2 = make_config(active_model="strawberry")
|
||||
await agent.reload_with_initial_messages(config=config2)
|
||||
|
||||
async for _ in agent.act("Continue"):
|
||||
pass
|
||||
|
||||
tokens_after_second = (
|
||||
agent.stats.session_prompt_tokens + agent.stats.session_completion_tokens
|
||||
)
|
||||
assert tokens_after_second > tokens_after_first
|
||||
|
||||
|
||||
class TestReloadPreservesMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_preserves_conversation_messages(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
assert len(agent.messages) == 3
|
||||
old_user_content = agent.messages[1].content
|
||||
old_assistant_content = agent.messages[2].content
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert len(agent.messages) == 3
|
||||
assert agent.messages[0].role == Role.system
|
||||
assert agent.messages[1].role == Role.user
|
||||
assert agent.messages[1].content == old_user_content
|
||||
assert agent.messages[2].role == Role.assistant
|
||||
assert agent.messages[2].content == old_assistant_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_updates_system_prompt_preserves_rest(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
config1 = make_config(system_prompt_id="tests")
|
||||
agent = Agent(config1, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
old_system = agent.messages[0].content
|
||||
old_user = agent.messages[1].content
|
||||
|
||||
config2 = make_config(system_prompt_id="cli")
|
||||
await agent.reload_with_initial_messages(config=config2)
|
||||
|
||||
assert agent.messages[0].content != old_system
|
||||
assert agent.messages[1].content == old_user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_with_no_messages_stays_empty(self) -> None:
|
||||
backend = FakeBackend([])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
assert len(agent.messages) == 1
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert len(agent.messages) == 1
|
||||
assert agent.messages[0].role == Role.system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_notifies_observer_with_all_messages(
|
||||
self, observer_capture
|
||||
) -> None:
|
||||
observed, observer = observer_capture
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), message_observer=observer, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
observed.clear()
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert len(observed) == 3
|
||||
assert observed[0].role == Role.system
|
||||
assert observed[1].role == Role.user
|
||||
assert observed[2].role == Role.assistant
|
||||
|
||||
|
||||
class TestCompactStatsHandling:
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_preserves_cumulative_stats(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="First response", finish_reason="stop"),
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Build something"):
|
||||
pass
|
||||
|
||||
tokens_before_compact = agent.stats.session_prompt_tokens
|
||||
completions_before = agent.stats.session_completion_tokens
|
||||
steps_before = agent.stats.steps
|
||||
|
||||
await agent.compact()
|
||||
|
||||
# Cumulative stats include the compact turn
|
||||
assert agent.stats.session_prompt_tokens > tokens_before_compact
|
||||
assert agent.stats.session_completion_tokens > completions_before
|
||||
assert agent.stats.steps > steps_before
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_updates_context_tokens(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Long response " * 100, finish_reason="stop"),
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Do something complex"):
|
||||
pass
|
||||
|
||||
context_before = agent.stats.context_tokens
|
||||
|
||||
await agent.compact()
|
||||
|
||||
assert agent.stats.context_tokens < context_before
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_preserves_tool_call_stats(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Using tool",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id="tc1",
|
||||
function=FunctionCall(
|
||||
name="todo", arguments='{"action": "read"}'
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
mock_llm_chunk(content="Done", finish_reason="stop"),
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
])
|
||||
config = make_config(enabled_tools=["todo"])
|
||||
agent = Agent(config, auto_approve=True, backend=backend)
|
||||
|
||||
async for _ in agent.act("Check todos"):
|
||||
pass
|
||||
|
||||
assert agent.stats.tool_calls_succeeded == 1
|
||||
|
||||
await agent.compact()
|
||||
|
||||
assert agent.stats.tool_calls_succeeded == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_resets_session_id(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Long response " * 100, finish_reason="stop"),
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(disable_logging=False), backend=backend)
|
||||
|
||||
original_session_id = agent.session_id
|
||||
original_logger_session_id = agent.interaction_logger.session_id
|
||||
|
||||
assert agent.session_id == original_logger_session_id
|
||||
|
||||
async for _ in agent.act("Do something complex"):
|
||||
pass
|
||||
|
||||
await agent.compact()
|
||||
|
||||
assert agent.session_id != original_session_id
|
||||
assert agent.session_id == agent.interaction_logger.session_id
|
||||
|
||||
|
||||
class TestAutoCompactIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_compact_triggers_and_preserves_stats(self) -> None:
|
||||
observed: list[tuple[Role, str | None]] = []
|
||||
|
||||
def observer(msg: LLMMessage) -> None:
|
||||
observed.append((msg.role, msg.content))
|
||||
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
mock_llm_chunk(content="<final>", finish_reason="stop"),
|
||||
])
|
||||
cfg = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
auto_compact_threshold=1,
|
||||
)
|
||||
agent = Agent(cfg, message_observer=observer, backend=backend)
|
||||
agent.stats.context_tokens = 2
|
||||
|
||||
events = [ev async for ev in agent.act("Hello")]
|
||||
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], CompactStartEvent)
|
||||
assert isinstance(events[1], CompactEndEvent)
|
||||
assert isinstance(events[2], AssistantEvent)
|
||||
|
||||
start: CompactStartEvent = events[0]
|
||||
end: CompactEndEvent = events[1]
|
||||
final: AssistantEvent = events[2]
|
||||
|
||||
assert start.current_context_tokens == 2
|
||||
assert start.threshold == 1
|
||||
assert end.old_context_tokens == 2
|
||||
assert end.new_context_tokens >= 1
|
||||
assert final.content == "<final>"
|
||||
|
||||
roles = [r for r, _ in observed]
|
||||
assert roles == [Role.system, Role.user, Role.assistant]
|
||||
assert (
|
||||
observed[1][1] is not None
|
||||
and "Last request from user was: Hello" in observed[1][1]
|
||||
)
|
||||
|
||||
|
||||
class TestClearHistoryFullReset:
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history_fully_resets_stats(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
assert agent.stats.session_prompt_tokens > 0
|
||||
assert agent.stats.steps > 0
|
||||
|
||||
await agent.clear_history()
|
||||
|
||||
assert agent.stats.session_prompt_tokens == 0
|
||||
assert agent.stats.session_completion_tokens == 0
|
||||
assert agent.stats.steps == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history_preserves_pricing(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
config = make_config(input_price=0.4, output_price=2.0)
|
||||
agent = Agent(config, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
await agent.clear_history()
|
||||
|
||||
assert agent.stats.input_price_per_million == 0.4
|
||||
assert agent.stats.output_price_per_million == 2.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history_removes_messages(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
assert len(agent.messages) == 3
|
||||
|
||||
await agent.clear_history()
|
||||
|
||||
assert len(agent.messages) == 1
|
||||
assert agent.messages[0].role == Role.system
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history_resets_session_id(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
agent = Agent(make_config(disable_logging=False), backend=backend)
|
||||
|
||||
original_session_id = agent.session_id
|
||||
original_logger_session_id = agent.interaction_logger.session_id
|
||||
|
||||
assert agent.session_id == original_logger_session_id
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
await agent.clear_history()
|
||||
|
||||
assert agent.session_id != original_session_id
|
||||
assert agent.session_id == agent.interaction_logger.session_id
|
||||
|
||||
|
||||
class TestStatsEdgeCases:
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_cost_approximation_on_model_change(
|
||||
self, monkeypatch
|
||||
) -> None:
|
||||
monkeypatch.setenv("LECHAT_API_KEY", "mock-key")
|
||||
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Response", finish_reason="stop")
|
||||
])
|
||||
config1 = make_config(active_model="devstral-latest")
|
||||
agent = Agent(config1, backend=backend)
|
||||
|
||||
async for _ in agent.act("Hello"):
|
||||
pass
|
||||
|
||||
cost_before = agent.stats.session_cost
|
||||
|
||||
config2 = make_config(active_model="strawberry")
|
||||
await agent.reload_with_initial_messages(config=config2)
|
||||
|
||||
cost_after = agent.stats.session_cost
|
||||
|
||||
assert cost_after > cost_before
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_reloads_accumulate_correctly(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="R1", finish_reason="stop"),
|
||||
mock_llm_chunk(content="R2", finish_reason="stop"),
|
||||
mock_llm_chunk(content="R3", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("First"):
|
||||
pass
|
||||
tokens1 = agent.stats.session_total_llm_tokens
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
async for _ in agent.act("Second"):
|
||||
pass
|
||||
tokens2 = agent.stats.session_total_llm_tokens
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
async for _ in agent.act("Third"):
|
||||
pass
|
||||
tokens3 = agent.stats.session_total_llm_tokens
|
||||
|
||||
assert tokens1 < tokens2 < tokens3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_then_reload_preserves_both(self) -> None:
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Initial response", finish_reason="stop"),
|
||||
mock_llm_chunk(content="<summary>", finish_reason="stop"),
|
||||
mock_llm_chunk(content="After reload", finish_reason="stop"),
|
||||
])
|
||||
agent = Agent(make_config(), backend=backend)
|
||||
|
||||
async for _ in agent.act("Build something"):
|
||||
pass
|
||||
|
||||
await agent.compact()
|
||||
tokens_after_compact = agent.stats.session_prompt_tokens
|
||||
|
||||
await agent.reload_with_initial_messages()
|
||||
|
||||
assert agent.stats.session_prompt_tokens == tokens_after_compact
|
||||
|
||||
async for _ in agent.act("Continue"):
|
||||
pass
|
||||
|
||||
assert agent.stats.session_prompt_tokens > tokens_after_compact
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_without_config_preserves_current(self) -> None:
|
||||
backend = FakeBackend([])
|
||||
original_config = make_config(active_model="devstral-latest")
|
||||
agent = Agent(original_config, backend=backend)
|
||||
|
||||
await agent.reload_with_initial_messages(config=None)
|
||||
|
||||
assert agent.config.active_model == "devstral-latest"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reload_with_new_config_updates_it(self) -> None:
|
||||
backend = FakeBackend([])
|
||||
original_config = make_config(active_model="devstral-latest")
|
||||
agent = Agent(original_config, backend=backend)
|
||||
|
||||
new_config = make_config(active_model="devstral-small")
|
||||
await agent.reload_with_initial_messages(config=new_config)
|
||||
|
||||
assert agent.config.active_model == "devstral-small"
|
||||
477
tests/test_agent_tool_call.py
Normal file
477
tests/test_agent_tool_call.py
Normal file
@@ -0,0 +1,477 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from tests.stubs.fake_tool import FakeTool
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
from vibe.core.tools.base import BaseToolConfig, ToolPermission
|
||||
from vibe.core.tools.builtins.todo import TodoItem
|
||||
from vibe.core.types import (
|
||||
AssistantEvent,
|
||||
BaseEvent,
|
||||
FunctionCall,
|
||||
LLMMessage,
|
||||
Role,
|
||||
SyncApprovalCallback,
|
||||
ToolCall,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from vibe.core.utils import ApprovalResponse
|
||||
|
||||
|
||||
async def act_and_collect_events(agent: Agent, prompt: str) -> list[BaseEvent]:
|
||||
return [ev async for ev in agent.act(prompt)]
|
||||
|
||||
|
||||
def make_config(todo_permission: ToolPermission = ToolPermission.ALWAYS) -> VibeConfig:
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
auto_compact_threshold=0,
|
||||
enabled_tools=["todo"],
|
||||
tools={"todo": BaseToolConfig(permission=todo_permission)},
|
||||
system_prompt_id="tests",
|
||||
include_project_context=False,
|
||||
include_prompt_detail=False,
|
||||
)
|
||||
|
||||
|
||||
def make_todo_tool_call(call_id: str, action: str = "read") -> ToolCall:
|
||||
return ToolCall(
|
||||
id=call_id,
|
||||
function=FunctionCall(name="todo", arguments=f'{{"action": "{action}"}}'),
|
||||
)
|
||||
|
||||
|
||||
def make_agent(
|
||||
*,
|
||||
auto_approve: bool = True,
|
||||
todo_permission: ToolPermission = ToolPermission.ALWAYS,
|
||||
backend: FakeBackend,
|
||||
approval_callback: SyncApprovalCallback | None = None,
|
||||
) -> Agent:
|
||||
agent = Agent(
|
||||
make_config(todo_permission=todo_permission),
|
||||
auto_approve=auto_approve,
|
||||
backend=backend,
|
||||
)
|
||||
if approval_callback:
|
||||
agent.set_approval_callback(approval_callback)
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_tool_call_executes_under_auto_approve() -> None:
|
||||
mocked_tool_call_id = "call_1"
|
||||
tool_call = make_todo_tool_call(mocked_tool_call_id)
|
||||
backend = FakeBackend([
|
||||
mock_llm_chunk(content="Let me check your todos.", tool_calls=[tool_call]),
|
||||
mock_llm_chunk(content="I retrieved 0 todos.", finish_reason="stop"),
|
||||
])
|
||||
agent = make_agent(auto_approve=True, backend=backend)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert [type(e) for e in events] == [
|
||||
AssistantEvent,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
AssistantEvent,
|
||||
]
|
||||
assert isinstance(events[0], AssistantEvent)
|
||||
assert events[0].content == "Let me check your todos."
|
||||
assert isinstance(events[1], ToolCallEvent)
|
||||
assert events[1].tool_name == "todo"
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is None
|
||||
assert events[2].skipped is False
|
||||
assert events[2].result is not None
|
||||
assert isinstance(events[3], AssistantEvent)
|
||||
assert events[3].content == "I retrieved 0 todos."
|
||||
# check conversation history
|
||||
tool_msgs = [m for m in agent.messages if m.role == Role.tool]
|
||||
assert len(tool_msgs) == 1
|
||||
assert tool_msgs[-1].tool_call_id == mocked_tool_call_id
|
||||
assert "total_count" in (tool_msgs[-1].content or "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_requires_approval_if_not_auto_approved() -> None:
|
||||
agent = make_agent(
|
||||
auto_approve=False,
|
||||
todo_permission=ToolPermission.ASK,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Let me check your todos.",
|
||||
tool_calls=[make_todo_tool_call("call_2")],
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="I cannot execute the tool without approval.",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert isinstance(events[1], ToolCallEvent)
|
||||
assert events[1].tool_name == "todo"
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].skipped is True
|
||||
assert events[2].error is None
|
||||
assert events[2].result is None
|
||||
assert events[2].skip_reason is not None
|
||||
assert "not permitted" in events[2].skip_reason.lower()
|
||||
assert isinstance(events[3], AssistantEvent)
|
||||
assert events[3].content == "I cannot execute the tool without approval."
|
||||
assert agent.stats.tool_calls_rejected == 1
|
||||
assert agent.stats.tool_calls_agreed == 0
|
||||
assert agent.stats.tool_calls_succeeded == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_approved_by_callback() -> None:
|
||||
def approval_callback(
|
||||
_tool_name: str, _args: dict[str, Any], _tool_call_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
return (ApprovalResponse.YES, None)
|
||||
|
||||
agent = make_agent(
|
||||
auto_approve=False,
|
||||
todo_permission=ToolPermission.ASK,
|
||||
approval_callback=approval_callback,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Let me check your todos.",
|
||||
tool_calls=[make_todo_tool_call("call_3")],
|
||||
),
|
||||
mock_llm_chunk(content="I retrieved 0 todos.", finish_reason="stop"),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].skipped is False
|
||||
assert events[2].error is None
|
||||
assert events[2].result is not None
|
||||
assert agent.stats.tool_calls_agreed == 1
|
||||
assert agent.stats.tool_calls_rejected == 0
|
||||
assert agent.stats.tool_calls_succeeded == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_rejected_when_auto_approve_disabled_and_rejected_by_callback() -> (
|
||||
None
|
||||
):
|
||||
custom_feedback = "User declined tool execution"
|
||||
|
||||
def approval_callback(
|
||||
_tool_name: str, _args: dict[str, Any], _tool_call_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
return (ApprovalResponse.NO, custom_feedback)
|
||||
|
||||
agent = make_agent(
|
||||
auto_approve=False,
|
||||
todo_permission=ToolPermission.ASK,
|
||||
approval_callback=approval_callback,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Let me check your todos.",
|
||||
tool_calls=[make_todo_tool_call("call_4")],
|
||||
),
|
||||
mock_llm_chunk(
|
||||
content="Understood, I won't check the todos.", finish_reason="stop"
|
||||
),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].skipped is True
|
||||
assert events[2].error is None
|
||||
assert events[2].result is None
|
||||
assert events[2].skip_reason == custom_feedback
|
||||
assert agent.stats.tool_calls_rejected == 1
|
||||
assert agent.stats.tool_calls_agreed == 0
|
||||
assert agent.stats.tool_calls_succeeded == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_skipped_when_permission_is_never() -> None:
|
||||
agent = make_agent(
|
||||
auto_approve=False,
|
||||
todo_permission=ToolPermission.NEVER,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Let me check your todos.",
|
||||
tool_calls=[make_todo_tool_call("call_never")],
|
||||
),
|
||||
mock_llm_chunk(content="Tool is disabled.", finish_reason="stop"),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].skipped is True
|
||||
assert events[2].error is None
|
||||
assert events[2].result is None
|
||||
assert events[2].skip_reason is not None
|
||||
assert "permanently disabled" in events[2].skip_reason.lower()
|
||||
tool_msgs = [m for m in agent.messages if m.role == Role.tool and m.name == "todo"]
|
||||
assert len(tool_msgs) == 1
|
||||
assert tool_msgs[0].name == "todo"
|
||||
assert events[2].skip_reason in (tool_msgs[-1].content or "")
|
||||
assert agent.stats.tool_calls_rejected == 1
|
||||
assert agent.stats.tool_calls_agreed == 0
|
||||
assert agent.stats.tool_calls_succeeded == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_approval_always_flips_auto_approve_for_subsequent_calls() -> None:
|
||||
callback_invocations = []
|
||||
|
||||
def approval_callback(
|
||||
tool_name: str, _args: dict[str, Any], _tool_call_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
callback_invocations.append(tool_name)
|
||||
return (ApprovalResponse.ALWAYS, None)
|
||||
|
||||
agent = make_agent(
|
||||
auto_approve=False,
|
||||
todo_permission=ToolPermission.ASK,
|
||||
approval_callback=approval_callback,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="First check.", tool_calls=[make_todo_tool_call("call_first")]
|
||||
),
|
||||
mock_llm_chunk(content="First done.", finish_reason="stop"),
|
||||
mock_llm_chunk(
|
||||
content="Second check.", tool_calls=[make_todo_tool_call("call_second")]
|
||||
),
|
||||
mock_llm_chunk(content="Second done.", finish_reason="stop"),
|
||||
]),
|
||||
)
|
||||
|
||||
events1 = await act_and_collect_events(agent, "First request")
|
||||
events2 = await act_and_collect_events(agent, "Second request")
|
||||
|
||||
assert agent.auto_approve is True
|
||||
assert len(callback_invocations) == 1
|
||||
assert callback_invocations[0] == "todo"
|
||||
assert isinstance(events1[2], ToolResultEvent)
|
||||
assert events1[2].skipped is False
|
||||
assert events1[2].result is not None
|
||||
assert isinstance(events2[2], ToolResultEvent)
|
||||
assert events2[2].skipped is False
|
||||
assert events2[2].result is not None
|
||||
assert agent.stats.tool_calls_rejected == 0
|
||||
assert agent.stats.tool_calls_succeeded == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_invalid_action() -> None:
|
||||
tool_call = ToolCall(
|
||||
id="call_5",
|
||||
function=FunctionCall(name="todo", arguments='{"action": "invalid_action"}'),
|
||||
)
|
||||
agent = make_agent(
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(content="Let me check your todos.", tool_calls=[tool_call]),
|
||||
mock_llm_chunk(
|
||||
content="I encountered an error with the action.", finish_reason="stop"
|
||||
),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "What's my todo list?")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is not None
|
||||
assert events[2].result is None
|
||||
assert "tool_error" in events[2].error.lower()
|
||||
assert agent.stats.tool_calls_failed == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_duplicate_todo_ids() -> None:
|
||||
duplicate_todos = [
|
||||
TodoItem(id="duplicate", content="Task 1"),
|
||||
TodoItem(id="duplicate", content="Task 2"),
|
||||
]
|
||||
tool_call = ToolCall(
|
||||
id="call_6",
|
||||
function=FunctionCall(
|
||||
name="todo",
|
||||
arguments=json.dumps({
|
||||
"action": "write",
|
||||
"todos": [t.model_dump() for t in duplicate_todos],
|
||||
}),
|
||||
),
|
||||
)
|
||||
agent = make_agent(
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(content="Let me write todos.", tool_calls=[tool_call]),
|
||||
mock_llm_chunk(
|
||||
content="I couldn't write todos with duplicate IDs.",
|
||||
finish_reason="stop",
|
||||
),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "Add todos")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is not None
|
||||
assert events[2].result is None
|
||||
assert "unique" in events[2].error.lower()
|
||||
assert agent.stats.tool_calls_failed == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_with_exceeding_max_todos() -> None:
|
||||
many_todos = [TodoItem(id=f"todo_{i}", content=f"Task {i}") for i in range(150)]
|
||||
tool_call = ToolCall(
|
||||
id="call_7",
|
||||
function=FunctionCall(
|
||||
name="todo",
|
||||
arguments=json.dumps({
|
||||
"action": "write",
|
||||
"todos": [t.model_dump() for t in many_todos],
|
||||
}),
|
||||
),
|
||||
)
|
||||
agent = make_agent(
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(content="Let me write todos.", tool_calls=[tool_call]),
|
||||
mock_llm_chunk(
|
||||
content="I couldn't write that many todos.", finish_reason="stop"
|
||||
),
|
||||
]),
|
||||
)
|
||||
|
||||
events = await act_and_collect_events(agent, "Add todos")
|
||||
|
||||
assert isinstance(events[2], ToolResultEvent)
|
||||
assert events[2].error is not None
|
||||
assert events[2].result is None
|
||||
assert "100" in events[2].error
|
||||
assert agent.stats.tool_calls_failed == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"exception_class",
|
||||
[
|
||||
pytest.param(KeyboardInterrupt, id="keyboard_interrupt"),
|
||||
pytest.param(asyncio.CancelledError, id="asyncio_cancelled"),
|
||||
],
|
||||
)
|
||||
async def test_tool_call_can_be_interrupted(
|
||||
exception_class: type[BaseException],
|
||||
) -> None:
|
||||
tool_call = ToolCall(
|
||||
id="call_8", function=FunctionCall(name="stub_tool", arguments="{}")
|
||||
)
|
||||
config = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
auto_compact_threshold=0,
|
||||
enabled_tools=["stub_tool"],
|
||||
)
|
||||
agent = Agent(
|
||||
config,
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([
|
||||
mock_llm_chunk(content="Let me use the tool.", tool_calls=[tool_call]),
|
||||
mock_llm_chunk(content="Tool execution completed.", finish_reason="stop"),
|
||||
]),
|
||||
)
|
||||
# no dependency injection available => monkey patch
|
||||
agent.tool_manager._available["stub_tool"] = FakeTool
|
||||
stub_tool_instance = agent.tool_manager.get("stub_tool")
|
||||
assert isinstance(stub_tool_instance, FakeTool)
|
||||
stub_tool_instance._exception_to_raise = exception_class()
|
||||
|
||||
events: list[BaseEvent] = []
|
||||
with pytest.raises(exception_class):
|
||||
async for ev in agent.act("Execute tool"):
|
||||
events.append(ev)
|
||||
|
||||
tool_result_event = next(
|
||||
(e for e in events if isinstance(e, ToolResultEvent)), None
|
||||
)
|
||||
assert tool_result_event is not None
|
||||
assert tool_result_event.error is not None
|
||||
assert "execution interrupted by user" in tool_result_event.error.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fill_missing_tool_responses_inserts_placeholders() -> None:
|
||||
agent = Agent(
|
||||
make_config(),
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([mock_llm_chunk(content="ok", finish_reason="stop")]),
|
||||
)
|
||||
tool_calls_messages = [
|
||||
ToolCall(
|
||||
id="tc1", function=FunctionCall(name="todo", arguments='{"action": "read"}')
|
||||
),
|
||||
ToolCall(
|
||||
id="tc2", function=FunctionCall(name="todo", arguments='{"action": "read"}')
|
||||
),
|
||||
]
|
||||
assistant_msg = LLMMessage(
|
||||
role=Role.assistant, content="Calling tools...", tool_calls=tool_calls_messages
|
||||
)
|
||||
agent.messages = [
|
||||
agent.messages[0],
|
||||
assistant_msg,
|
||||
# only one tool responded: the second is missing
|
||||
LLMMessage(
|
||||
role=Role.tool, tool_call_id="tc1", name="todo", content="Retrieved 0 todos"
|
||||
),
|
||||
]
|
||||
|
||||
await act_and_collect_events(agent, "Proceed")
|
||||
|
||||
tool_msgs = [m for m in agent.messages if m.role == Role.tool]
|
||||
assert any(m.tool_call_id == "tc2" for m in tool_msgs)
|
||||
# find placeholder message for tc2
|
||||
placeholder = next(m for m in tool_msgs if m.tool_call_id == "tc2")
|
||||
assert placeholder.name == "todo"
|
||||
assert (
|
||||
placeholder.content
|
||||
== "<user_cancellation>Tool execution interrupted - no response available</user_cancellation>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_assistant_after_tool_appends_understood() -> None:
|
||||
agent = Agent(
|
||||
make_config(),
|
||||
auto_approve=True,
|
||||
backend=FakeBackend([mock_llm_chunk(content="ok", finish_reason="stop")]),
|
||||
)
|
||||
tool_msg = LLMMessage(
|
||||
role=Role.tool, tool_call_id="tc_z", name="todo", content="Done"
|
||||
)
|
||||
agent.messages = [agent.messages[0], tool_msg]
|
||||
|
||||
await act_and_collect_events(agent, "Next")
|
||||
|
||||
# find the seeded tool message and ensure the next message is "Understood."
|
||||
idx = next(i for i, m in enumerate(agent.messages) if m.role == Role.tool)
|
||||
assert agent.messages[idx + 1].role == Role.assistant
|
||||
assert agent.messages[idx + 1].content == "Understood."
|
||||
137
tests/test_cli_programmatic_preload.py
Normal file
137
tests/test_cli_programmatic_preload.py
Normal file
@@ -0,0 +1,137 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.mock.mock_backend_factory import mock_backend_factory
|
||||
from tests.mock.utils import mock_llm_chunk
|
||||
from tests.stubs.fake_backend import FakeBackend
|
||||
from vibe.core import run_programmatic
|
||||
from vibe.core.config import Backend, SessionLoggingConfig, VibeConfig
|
||||
from vibe.core.types import LLMMessage, OutputFormat, Role
|
||||
|
||||
|
||||
class SpyStreamingFormatter:
|
||||
def __init__(self) -> None:
|
||||
self.emitted: list[tuple[Role, str | None]] = []
|
||||
|
||||
def on_message_added(self, message: LLMMessage) -> None:
|
||||
self.emitted.append((message.role, message.content))
|
||||
|
||||
def on_event(self, _event) -> None: # No-op for this test
|
||||
pass
|
||||
|
||||
def finalize(self) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def test_run_programmatic_preload_streaming_is_batched(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
spy = SpyStreamingFormatter()
|
||||
monkeypatch.setattr(
|
||||
"vibe.core.programmatic.create_formatter", lambda *_args, **_kwargs: spy
|
||||
)
|
||||
|
||||
with mock_backend_factory(
|
||||
Backend.MISTRAL,
|
||||
lambda provider, **kwargs: FakeBackend([
|
||||
mock_llm_chunk(
|
||||
content="Decorators are wrappers that modify function behavior.",
|
||||
finish_reason="stop",
|
||||
)
|
||||
]),
|
||||
):
|
||||
cfg = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
system_prompt_id="tests",
|
||||
include_project_context=False,
|
||||
include_prompt_detail=False,
|
||||
include_model_info=False,
|
||||
)
|
||||
|
||||
previous = [
|
||||
LLMMessage(
|
||||
role=Role.system, content="This system message should be ignored."
|
||||
),
|
||||
LLMMessage(
|
||||
role=Role.user, content="Previously, you told me about decorators."
|
||||
),
|
||||
LLMMessage(
|
||||
role=Role.assistant,
|
||||
content="Sure, decorators allow you to wrap functions.",
|
||||
),
|
||||
]
|
||||
|
||||
run_programmatic(
|
||||
config=cfg,
|
||||
prompt="Can you summarize what decorators are?",
|
||||
output_format=OutputFormat.STREAMING,
|
||||
previous_messages=previous,
|
||||
)
|
||||
|
||||
roles = [r for r, _ in spy.emitted]
|
||||
assert roles == [
|
||||
Role.system,
|
||||
Role.user,
|
||||
Role.assistant,
|
||||
Role.user,
|
||||
Role.assistant,
|
||||
]
|
||||
assert (
|
||||
spy.emitted[0][1] == "You are Vibe, a super useful programming assistant."
|
||||
)
|
||||
assert spy.emitted[1][1] == "Previously, you told me about decorators."
|
||||
assert spy.emitted[2][1] == "Sure, decorators allow you to wrap functions."
|
||||
assert spy.emitted[3][1] == "Can you summarize what decorators are?"
|
||||
assert (
|
||||
spy.emitted[4][1]
|
||||
== "Decorators are wrappers that modify function behavior."
|
||||
)
|
||||
|
||||
|
||||
def test_run_programmatic_ignores_system_messages_in_previous(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
spy = SpyStreamingFormatter()
|
||||
monkeypatch.setattr(
|
||||
"vibe.core.programmatic.create_formatter", lambda *_args, **_kwargs: spy
|
||||
)
|
||||
|
||||
with mock_backend_factory(
|
||||
Backend.MISTRAL,
|
||||
lambda provider, **kwargs: FakeBackend([mock_llm_chunk(content="Understood.")]),
|
||||
):
|
||||
cfg = VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False),
|
||||
system_prompt_id="tests",
|
||||
include_project_context=False,
|
||||
include_prompt_detail=False,
|
||||
include_model_info=False,
|
||||
)
|
||||
|
||||
run_programmatic(
|
||||
config=cfg,
|
||||
prompt="Let's move on to practical examples.",
|
||||
output_format=OutputFormat.STREAMING,
|
||||
previous_messages=[
|
||||
LLMMessage(
|
||||
role=Role.system,
|
||||
content="First system message that should be ignored.",
|
||||
),
|
||||
LLMMessage(role=Role.user, content="Continue our previous discussion."),
|
||||
LLMMessage(
|
||||
role=Role.system,
|
||||
content="Second system message that should be ignored.",
|
||||
),
|
||||
],
|
||||
auto_approve=True,
|
||||
)
|
||||
|
||||
roles = [r for r, _ in spy.emitted]
|
||||
assert roles == [Role.system, Role.user, Role.user, Role.assistant]
|
||||
assert (
|
||||
spy.emitted[0][1] == "You are Vibe, a super useful programming assistant."
|
||||
)
|
||||
assert spy.emitted[1][1] == "Continue our previous discussion."
|
||||
assert spy.emitted[2][1] == "Let's move on to practical examples."
|
||||
assert spy.emitted[3][1] == "Understood."
|
||||
101
tests/test_history_manager.py
Normal file
101
tests/test_history_manager.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from vibe.cli.history_manager import HistoryManager
|
||||
|
||||
|
||||
def test_history_manager_normalizes_loaded_entries_like_numbers_to_strings(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
# ideally, we would not use real I/O; but this test is a quick bugfix, thus it
|
||||
# does not intend to refactor the HistoryManager
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
history_entries = ["hello", 123]
|
||||
history_file.write_text(
|
||||
"\n".join(json.dumps(entry) for entry in history_entries) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
manager = HistoryManager(history_file)
|
||||
|
||||
result = manager.get_previous(current_input="", prefix="1")
|
||||
|
||||
assert result == "123"
|
||||
|
||||
|
||||
def test_history_manager_retains_a_fixed_number_of_entries(tmp_path: Path) -> None:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
manager = HistoryManager(history_file, max_entries=3)
|
||||
|
||||
manager.add("first")
|
||||
manager.add("second")
|
||||
manager.add("third")
|
||||
manager.add("fourth")
|
||||
|
||||
reloaded = HistoryManager(history_file)
|
||||
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "fourth"
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "third"
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "second"
|
||||
# "first" is not proposed as we defined number of entries to 3
|
||||
assert reloaded.get_previous(current_input="", prefix="") is None
|
||||
|
||||
|
||||
def test_history_manager_filters_invalid_and_duplicated_entries(tmp_path: Path) -> None:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
manager = HistoryManager(history_file, max_entries=5)
|
||||
manager.add("") # empty
|
||||
manager.add(" ") # is trimmed
|
||||
manager.add("first")
|
||||
manager.add("second")
|
||||
manager.add("second") # duplicate
|
||||
manager.add("third")
|
||||
|
||||
reloaded = HistoryManager(history_file)
|
||||
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "third"
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "second"
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "first"
|
||||
assert reloaded.get_previous(current_input="", prefix="") is None
|
||||
assert reloaded.get_previous(current_input="", prefix="") is None
|
||||
|
||||
|
||||
def test_history_manager_filters_commands(tmp_path: Path) -> None:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
manager = HistoryManager(history_file, max_entries=5)
|
||||
manager.add("first")
|
||||
manager.add("/skip")
|
||||
|
||||
reloaded = HistoryManager(history_file)
|
||||
|
||||
assert reloaded.get_previous(current_input="", prefix="/") == None
|
||||
assert reloaded.get_previous(current_input="", prefix="") == "first"
|
||||
assert reloaded.get_previous(current_input="", prefix="") is None
|
||||
|
||||
|
||||
def test_history_manager_allows_navigation_round_trip(tmp_path: Path) -> None:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
manager = HistoryManager(history_file)
|
||||
|
||||
manager.add("alpha")
|
||||
manager.add("beta")
|
||||
|
||||
assert manager.get_previous(current_input="typed") == "beta"
|
||||
assert manager.get_previous(current_input="typed") == "alpha"
|
||||
assert manager.get_next() == "beta"
|
||||
assert manager.get_next() == "typed"
|
||||
assert manager.get_next() is None
|
||||
|
||||
|
||||
def test_history_manager_prefix_filtering(tmp_path: Path) -> None:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
manager = HistoryManager(history_file)
|
||||
|
||||
manager.add("foo")
|
||||
manager.add("bar")
|
||||
manager.add("fizz")
|
||||
|
||||
assert manager.get_previous(current_input="", prefix="f") == "fizz"
|
||||
assert manager.get_previous(current_input="", prefix="f") == "foo"
|
||||
assert manager.get_previous(current_input="", prefix="f") is None
|
||||
37
tests/test_system_prompt.py
Normal file
37
tests/test_system_prompt.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.config import VibeConfig
|
||||
from vibe.core.system_prompt import get_universal_system_prompt
|
||||
from vibe.core.tools.manager import ToolManager
|
||||
|
||||
|
||||
def test_get_universal_system_prompt_includes_windows_prompt_on_windows(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(sys, "platform", "win32")
|
||||
monkeypatch.setenv("COMSPEC", "C:\\Windows\\System32\\cmd.exe")
|
||||
|
||||
config = VibeConfig(
|
||||
system_prompt_id="tests",
|
||||
include_project_context=False,
|
||||
include_prompt_detail=True,
|
||||
include_model_info=False,
|
||||
)
|
||||
tool_manager = ToolManager(config)
|
||||
|
||||
prompt = get_universal_system_prompt(tool_manager, config)
|
||||
|
||||
assert "You are Vibe, a super useful programming assistant." in prompt
|
||||
assert (
|
||||
"The operating system is Windows with shell `C:\\Windows\\System32\\cmd.exe`"
|
||||
in prompt
|
||||
)
|
||||
assert "DO NOT use Unix commands like `ls`, `grep`, `cat`" in prompt
|
||||
assert "Use: `dir` (Windows) for directory listings" in prompt
|
||||
assert "Use: backslashes (\\\\) for paths" in prompt
|
||||
assert "Check command availability with: `where command` (Windows)" in prompt
|
||||
assert "Script shebang: Not applicable on Windows" in prompt
|
||||
107
tests/test_tagged_text.py
Normal file
107
tests/test_tagged_text.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.utils import CANCELLATION_TAG, KNOWN_TAGS, TaggedText
|
||||
|
||||
|
||||
def test_tagged_text_creation_without_tag() -> None:
|
||||
tagged = TaggedText("Hello world")
|
||||
assert tagged.message == "Hello world"
|
||||
assert tagged.tag == ""
|
||||
assert str(tagged) == "Hello world"
|
||||
|
||||
|
||||
def test_tagged_text_creation_with_tag() -> None:
|
||||
tagged = TaggedText("User cancelled", CANCELLATION_TAG)
|
||||
assert tagged.message == "User cancelled"
|
||||
assert tagged.tag == CANCELLATION_TAG
|
||||
assert str(tagged) == f"<{CANCELLATION_TAG}>User cancelled</{CANCELLATION_TAG}>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tag", KNOWN_TAGS)
|
||||
def test_tagged_text_from_string_with_known_tag(tag: str) -> None:
|
||||
text = f"<{tag}>This is a tagged text</{tag}>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "This is a tagged text"
|
||||
assert tagged.tag == tag
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tag", KNOWN_TAGS)
|
||||
def test_tagged_text_from_string_with_known_tag_multiline(tag: str) -> None:
|
||||
text = f"<{tag}>This is a tagged text</{tag}>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "This is a tagged text"
|
||||
assert tagged.tag == tag
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tag", KNOWN_TAGS)
|
||||
def test_tagged_text_from_string_with_known_tag_whitespace(tag: str) -> None:
|
||||
text = f"<{tag}> This is a tagged text </{tag}>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == " This is a tagged text "
|
||||
assert tagged.tag == tag
|
||||
|
||||
|
||||
def test_tagged_text_from_string_with_unknown_tag() -> None:
|
||||
text = "<unknown_tag>Some content</unknown_tag>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "<unknown_tag>Some content</unknown_tag>"
|
||||
assert tagged.tag == ""
|
||||
|
||||
|
||||
def test_tagged_text_from_string_with_text_before_tag() -> None:
|
||||
text = f"Prefix text <{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "Prefix text Content"
|
||||
assert tagged.tag == CANCELLATION_TAG
|
||||
|
||||
|
||||
def test_tagged_text_from_string_with_text_after_tag() -> None:
|
||||
text = f"<{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}> Suffix text"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "Content Suffix text"
|
||||
assert tagged.tag == CANCELLATION_TAG
|
||||
|
||||
|
||||
def test_tagged_text_from_string_with_text_before_and_after_tag() -> None:
|
||||
text = f"Before <{CANCELLATION_TAG}>Content</{CANCELLATION_TAG}> After"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "Before Content After"
|
||||
assert tagged.tag == CANCELLATION_TAG
|
||||
|
||||
|
||||
def test_tagged_text_from_string_without_tags() -> None:
|
||||
text = "Just plain text without any tags"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == "Just plain text without any tags"
|
||||
assert tagged.tag == ""
|
||||
|
||||
|
||||
def test_tagged_text_from_string_empty() -> None:
|
||||
tagged = TaggedText.from_string("")
|
||||
assert tagged.message == ""
|
||||
assert tagged.tag == ""
|
||||
|
||||
|
||||
def test_tagged_text_from_string_mismatched_tags() -> None:
|
||||
text = f"<{CANCELLATION_TAG}>Content</different_tag>"
|
||||
tagged = TaggedText.from_string(text)
|
||||
assert tagged.message == f"<{CANCELLATION_TAG}>Content</different_tag>"
|
||||
assert tagged.tag == ""
|
||||
|
||||
|
||||
def test_tagged_text_round_trip() -> None:
|
||||
original = TaggedText("User cancelled", CANCELLATION_TAG)
|
||||
text = str(original)
|
||||
parsed = TaggedText.from_string(text)
|
||||
assert parsed.message == original.message
|
||||
assert parsed.tag == original.tag
|
||||
|
||||
|
||||
def test_tagged_text_round_trip_no_tag() -> None:
|
||||
original = TaggedText("Plain message")
|
||||
text = str(original)
|
||||
parsed = TaggedText.from_string(text)
|
||||
assert parsed.message == original.message
|
||||
assert parsed.tag == original.tag
|
||||
130
tests/test_ui_input_history.py
Normal file
130
tests/test_ui_input_history.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.cli.history_manager import HistoryManager
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.textual_ui.widgets.chat_input.body import ChatInputBody
|
||||
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config() -> VibeConfig:
|
||||
return VibeConfig(session_logging=SessionLoggingConfig(enabled=False))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_app(vibe_config: VibeConfig, tmp_path: Path) -> VibeApp:
|
||||
return VibeApp(config=vibe_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def history_file(tmp_path: Path) -> Path:
|
||||
history_file = tmp_path / "history.jsonl"
|
||||
history_entries = ["hello", "hi there", "how are you?"]
|
||||
history_file.write_text(
|
||||
"\n".join(json.dumps(entry) for entry in history_entries) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return history_file
|
||||
|
||||
|
||||
def inject_history_file(vibe_app: VibeApp, history_file: Path) -> None:
|
||||
# Dependency Injection would help here, but as we don't have it yet: manual injection
|
||||
chat_input_body = vibe_app.query_one(ChatInputBody)
|
||||
chat_input_body.history = HistoryManager(history_file)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_navigation_through_input_history(
|
||||
vibe_app: VibeApp, history_file: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
inject_history_file(vibe_app, history_file)
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
|
||||
await pilot.press("up")
|
||||
assert chat_input.value == "how are you?"
|
||||
await pilot.press("up")
|
||||
assert chat_input.value == "hi there"
|
||||
await pilot.press("up")
|
||||
assert chat_input.value == "hello"
|
||||
await pilot.press("up")
|
||||
# cannot go further up
|
||||
assert chat_input.value == "hello"
|
||||
await pilot.press("down")
|
||||
assert chat_input.value == "hi there"
|
||||
await pilot.press("down")
|
||||
assert chat_input.value == "how are you?"
|
||||
await pilot.press("down")
|
||||
assert chat_input.value == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_does_nothing_if_command_completion_is_active(
|
||||
vibe_app: VibeApp, history_file: Path
|
||||
) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
inject_history_file(vibe_app, history_file)
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
|
||||
await pilot.press("/")
|
||||
assert chat_input.value == "/"
|
||||
await pilot.press("up")
|
||||
assert chat_input.value == "/"
|
||||
await pilot.press("down")
|
||||
assert chat_input.value == "/"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_does_not_prevent_arrow_down_to_move_cursor_to_bottom_lines(
|
||||
vibe_app: VibeApp,
|
||||
):
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
textarea = chat_input.input_widget
|
||||
assert textarea is not None
|
||||
|
||||
await pilot.press(*"test")
|
||||
await pilot.press("ctrl+j", "ctrl+j")
|
||||
assert chat_input.value == "test\n\n"
|
||||
assert textarea.text.count("\n") == 2
|
||||
initial_row = textarea.cursor_location[0]
|
||||
assert initial_row == 2, f"Expected cursor on line 2, got line {initial_row}"
|
||||
await pilot.press("up")
|
||||
assert textarea.cursor_location[0] == 1, "First arrow up should move to line 1"
|
||||
await pilot.press("up")
|
||||
assert textarea.cursor_location[0] == 0, (
|
||||
"Second arrow up should move to line 0 (first line)"
|
||||
)
|
||||
await pilot.press("down")
|
||||
final_row = textarea.cursor_location[0]
|
||||
assert final_row == 1, f"cursor is still on line {final_row}."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_resumes_arrow_down_after_manual_move(
|
||||
vibe_app: VibeApp, tmp_path: Path
|
||||
) -> None:
|
||||
history_path = tmp_path / "history.jsonl"
|
||||
history_path.write_text(
|
||||
json.dumps("first line\nsecond line") + "\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
async with vibe_app.run_test() as pilot:
|
||||
inject_history_file(vibe_app, history_path)
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
textarea = chat_input.input_widget
|
||||
assert textarea is not None
|
||||
|
||||
await pilot.press("up")
|
||||
assert chat_input.value == "first line\nsecond line"
|
||||
assert textarea.cursor_location == (0, len("first line"))
|
||||
await pilot.press("left")
|
||||
await pilot.press("down")
|
||||
assert textarea.cursor_location[0] == 1
|
||||
assert chat_input.value == "first line\nsecond line"
|
||||
161
tests/test_ui_pending_user_message.py
Normal file
161
tests/test_ui_pending_user_message.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
|
||||
from vibe.cli.textual_ui.widgets.messages import InterruptMessage, UserMessage
|
||||
from vibe.core.agent import Agent
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
from vibe.core.types import BaseEvent
|
||||
|
||||
|
||||
async def _wait_for(
|
||||
pilot, condition: Callable[[], object | None], timeout: float = 3.0
|
||||
) -> object | None:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
result = condition()
|
||||
if result:
|
||||
return result
|
||||
await pilot.pause(0.05)
|
||||
return None
|
||||
|
||||
|
||||
class StubAgent(Agent):
|
||||
def __init__(self) -> None:
|
||||
self.messages: list = []
|
||||
self.stats = SimpleNamespace(context_tokens=0)
|
||||
self.approval_callback = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
return
|
||||
|
||||
async def act(self, msg: str) -> AsyncGenerator[BaseEvent]:
|
||||
if False:
|
||||
yield msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config() -> VibeConfig:
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False), enable_update_checks=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
|
||||
return VibeApp(config=vibe_config)
|
||||
|
||||
|
||||
def _patch_delayed_init(
|
||||
monkeypatch: pytest.MonkeyPatch, init_event: asyncio.Event
|
||||
) -> None:
|
||||
async def _fake_initialize(self: VibeApp) -> None:
|
||||
if self.agent or self._agent_initializing:
|
||||
return
|
||||
|
||||
self._agent_initializing = True
|
||||
try:
|
||||
await init_event.wait()
|
||||
self.agent = StubAgent()
|
||||
except asyncio.CancelledError:
|
||||
self.agent = None
|
||||
return
|
||||
finally:
|
||||
self._agent_initializing = False
|
||||
self._agent_init_task = None
|
||||
|
||||
monkeypatch.setattr(VibeApp, "_initialize_agent", _fake_initialize, raising=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shows_user_message_as_pending_until_agent_is_initialized(
|
||||
vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
init_event = asyncio.Event()
|
||||
_patch_delayed_init(monkeypatch, init_event)
|
||||
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "Hello"
|
||||
|
||||
press_task = asyncio.create_task(pilot.press("enter"))
|
||||
|
||||
user_message = await _wait_for(
|
||||
pilot, lambda: next(iter(vibe_app.query(UserMessage)), None)
|
||||
)
|
||||
assert isinstance(user_message, UserMessage)
|
||||
assert user_message.has_class("pending")
|
||||
init_event.set()
|
||||
await press_task
|
||||
assert not user_message.has_class("pending")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_interrupt_pending_message_during_initialization(
|
||||
vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
init_event = asyncio.Event()
|
||||
_patch_delayed_init(monkeypatch, init_event)
|
||||
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "Hello"
|
||||
|
||||
press_task = asyncio.create_task(pilot.press("enter"))
|
||||
|
||||
user_message = await _wait_for(
|
||||
pilot, lambda: next(iter(vibe_app.query(UserMessage)), None)
|
||||
)
|
||||
assert isinstance(user_message, UserMessage)
|
||||
assert user_message.has_class("pending")
|
||||
|
||||
await pilot.press("escape")
|
||||
await press_task
|
||||
assert not user_message.has_class("pending")
|
||||
assert vibe_app.query(InterruptMessage)
|
||||
assert vibe_app.agent is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_initialization_after_interrupt(
|
||||
vibe_app: VibeApp, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
init_event = asyncio.Event()
|
||||
_patch_delayed_init(monkeypatch, init_event)
|
||||
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "First Message"
|
||||
press_task = asyncio.create_task(pilot.press("enter"))
|
||||
|
||||
await _wait_for(pilot, lambda: next(iter(vibe_app.query(UserMessage)), None))
|
||||
await pilot.press("escape")
|
||||
await press_task
|
||||
assert vibe_app.agent is None
|
||||
assert vibe_app._agent_init_task is None
|
||||
|
||||
chat_input.value = "Second Message"
|
||||
press_task_2 = asyncio.create_task(pilot.press("enter"))
|
||||
|
||||
def get_second_message():
|
||||
messages = list(vibe_app.query(UserMessage))
|
||||
if len(messages) >= 2:
|
||||
return messages[-1]
|
||||
return None
|
||||
|
||||
user_message_2 = await _wait_for(pilot, get_second_message)
|
||||
assert isinstance(user_message_2, UserMessage)
|
||||
assert user_message_2.has_class("pending")
|
||||
assert vibe_app.agent is None
|
||||
|
||||
init_event.set()
|
||||
await press_task_2
|
||||
assert not user_message_2.has_class("pending")
|
||||
assert vibe_app.agent is not None
|
||||
86
tests/tools/test_bash.py
Normal file
86
tests/tools/test_bash.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.tools.base import BaseToolState, ToolError, ToolPermission
|
||||
from vibe.core.tools.builtins.bash import Bash, BashArgs, BashToolConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bash(tmp_path):
|
||||
config = BashToolConfig(workdir=tmp_path)
|
||||
return Bash(config=config, state=BaseToolState())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runs_echo_successfully(bash):
|
||||
result = await bash.run(BashArgs(command="echo hello"))
|
||||
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == "hello\n"
|
||||
assert result.stderr == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_cat_command_with_missing_file(bash):
|
||||
with pytest.raises(ToolError) as err:
|
||||
await bash.run(BashArgs(command="cat missing_file.txt"))
|
||||
|
||||
message = str(err.value)
|
||||
assert "Command failed" in message
|
||||
assert "Return code: 1" in message
|
||||
assert "No such file or directory" in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_effective_workdir(tmp_path):
|
||||
config = BashToolConfig(workdir=tmp_path)
|
||||
bash_tool = Bash(config=config, state=BaseToolState())
|
||||
|
||||
result = await bash_tool.run(BashArgs(command="pwd"))
|
||||
|
||||
assert result.stdout.strip() == str(tmp_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_timeout(bash):
|
||||
with pytest.raises(ToolError) as err:
|
||||
await bash.run(BashArgs(command="sleep 2", timeout=1))
|
||||
|
||||
assert "Command timed out after 1s" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_output_to_max_bytes(bash):
|
||||
config = BashToolConfig(workdir=None, max_output_bytes=5)
|
||||
bash_tool = Bash(config=config, state=BaseToolState())
|
||||
|
||||
result = await bash_tool.run(BashArgs(command="printf 'abcdefghij'"))
|
||||
|
||||
assert result.stdout == "abcde"
|
||||
assert result.stderr == ""
|
||||
assert result.returncode == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decodes_non_utf8_bytes(bash):
|
||||
result = await bash.run(BashArgs(command="printf '\\xff\\xfe'"))
|
||||
|
||||
# accept both possible encodings, as some shells emit escaped bytes as literal strings
|
||||
assert result.stdout in {"<EFBFBD><EFBFBD>", "\xff\xfe", r"\xff\xfe"}
|
||||
assert result.stderr == ""
|
||||
|
||||
|
||||
def test_check_allowlist_denylist():
|
||||
config = BashToolConfig(allowlist=["echo", "pwd"], denylist=["rm"])
|
||||
bash_tool = Bash(config=config, state=BaseToolState())
|
||||
|
||||
allowlisted = bash_tool.check_allowlist_denylist(BashArgs(command="echo hi"))
|
||||
denylisted = bash_tool.check_allowlist_denylist(BashArgs(command="rm -rf /tmp"))
|
||||
mixed = bash_tool.check_allowlist_denylist(BashArgs(command="pwd && whoami"))
|
||||
empty = bash_tool.check_allowlist_denylist(BashArgs(command=""))
|
||||
|
||||
assert allowlisted is ToolPermission.ALWAYS
|
||||
assert denylisted is ToolPermission.NEVER
|
||||
assert mixed is None
|
||||
assert empty is None
|
||||
347
tests/tools/test_grep.py
Normal file
347
tests/tools/test_grep.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.grep import (
|
||||
Grep,
|
||||
GrepArgs,
|
||||
GrepBackend,
|
||||
GrepState,
|
||||
GrepToolConfig,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grep(tmp_path):
|
||||
config = GrepToolConfig(workdir=tmp_path)
|
||||
return Grep(config=config, state=GrepState())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def grep_gnu_only(tmp_path, monkeypatch):
|
||||
original_which = shutil.which
|
||||
|
||||
def mock_which(cmd):
|
||||
if cmd == "rg":
|
||||
return None
|
||||
return original_which(cmd)
|
||||
|
||||
monkeypatch.setattr("shutil.which", mock_which)
|
||||
config = GrepToolConfig(workdir=tmp_path)
|
||||
return Grep(config=config, state=GrepState())
|
||||
|
||||
|
||||
def test_detects_ripgrep_when_available(grep):
|
||||
if shutil.which("rg"):
|
||||
assert grep._detect_backend() == GrepBackend.RIPGREP
|
||||
|
||||
|
||||
def test_falls_back_to_gnu_grep(grep, monkeypatch):
|
||||
original_which = shutil.which
|
||||
|
||||
def mock_which(cmd):
|
||||
if cmd == "rg":
|
||||
return None
|
||||
return original_which(cmd)
|
||||
|
||||
monkeypatch.setattr("shutil.which", mock_which)
|
||||
|
||||
if shutil.which("grep"):
|
||||
assert grep._detect_backend() == GrepBackend.GNU_GREP
|
||||
|
||||
|
||||
def test_raises_error_if_no_grep_available(grep, monkeypatch):
|
||||
monkeypatch.setattr("shutil.which", lambda cmd: None)
|
||||
|
||||
with pytest.raises(ToolError) as err:
|
||||
grep._detect_backend()
|
||||
|
||||
assert "Neither ripgrep (rg) nor grep is installed" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_pattern_in_file(grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("def hello():\n print('world')\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="hello"))
|
||||
|
||||
assert result.match_count == 1
|
||||
assert "hello" in result.matches
|
||||
assert "test.py" in result.matches
|
||||
assert not result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_multiple_matches(grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("foo\nbar\nfoo\nbaz\nfoo\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="foo"))
|
||||
|
||||
assert result.match_count == 3
|
||||
assert result.matches.count("foo") == 3
|
||||
assert not result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_on_no_matches(grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("def hello():\n pass\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="nonexistent"))
|
||||
|
||||
assert result.match_count == 0
|
||||
assert result.matches == ""
|
||||
assert not result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_with_empty_pattern(grep):
|
||||
with pytest.raises(ToolError) as err:
|
||||
await grep.run(GrepArgs(pattern=""))
|
||||
|
||||
assert "Empty search pattern" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fails_with_nonexistent_path(grep):
|
||||
with pytest.raises(ToolError) as err:
|
||||
await grep.run(GrepArgs(pattern="test", path="nonexistent"))
|
||||
|
||||
assert "Path does not exist" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searches_in_specific_path(grep, tmp_path):
|
||||
subdir = tmp_path / "subdir"
|
||||
subdir.mkdir()
|
||||
(subdir / "test.py").write_text("match here\n")
|
||||
(tmp_path / "other.py").write_text("match here too\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="match", path="subdir"))
|
||||
|
||||
assert result.match_count == 1
|
||||
assert "subdir" in result.matches and "test.py" in result.matches
|
||||
assert "other.py" not in result.matches
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_to_max_matches(grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("\n".join(f"line {i}" for i in range(200)))
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="line", max_matches=50))
|
||||
|
||||
assert result.match_count == 50
|
||||
assert result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_to_max_output_bytes(grep, tmp_path):
|
||||
config = GrepToolConfig(workdir=tmp_path, max_output_bytes=100)
|
||||
grep_tool = Grep(config=config, state=GrepState())
|
||||
(tmp_path / "test.py").write_text("\n".join("x" * 100 for _ in range(10)))
|
||||
|
||||
result = await grep_tool.run(GrepArgs(pattern="x"))
|
||||
|
||||
assert len(result.matches) <= 100
|
||||
assert result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_default_ignore_patterns(grep, tmp_path):
|
||||
(tmp_path / "included.py").write_text("match\n")
|
||||
node_modules = tmp_path / "node_modules"
|
||||
node_modules.mkdir()
|
||||
(node_modules / "excluded.js").write_text("match\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="match"))
|
||||
|
||||
assert "included.py" in result.matches
|
||||
assert "excluded.js" not in result.matches
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_vibeignore_file(grep, tmp_path):
|
||||
(tmp_path / ".vibeignore").write_text("custom_dir/\n*.tmp\n")
|
||||
custom_dir = tmp_path / "custom_dir"
|
||||
custom_dir.mkdir()
|
||||
(custom_dir / "excluded.py").write_text("match\n")
|
||||
(tmp_path / "excluded.tmp").write_text("match\n")
|
||||
(tmp_path / "included.py").write_text("match\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="match"))
|
||||
|
||||
assert "included.py" in result.matches
|
||||
assert "excluded.py" not in result.matches
|
||||
assert "excluded.tmp" not in result.matches
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_comments_in_vibeignore(grep, tmp_path):
|
||||
(tmp_path / ".vibeignore").write_text("# comment\npattern/\n# another comment\n")
|
||||
(tmp_path / "file.py").write_text("match\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="match"))
|
||||
|
||||
assert result.match_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tracks_search_history(grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("content\n")
|
||||
|
||||
await grep.run(GrepArgs(pattern="first"))
|
||||
await grep.run(GrepArgs(pattern="second"))
|
||||
await grep.run(GrepArgs(pattern="third"))
|
||||
|
||||
assert grep.state.search_history == ["first", "second", "third"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_effective_workdir(tmp_path):
|
||||
config = GrepToolConfig(workdir=tmp_path)
|
||||
grep_tool = Grep(config=config, state=GrepState())
|
||||
(tmp_path / "test.py").write_text("match\n")
|
||||
|
||||
result = await grep_tool.run(GrepArgs(pattern="match", path="."))
|
||||
|
||||
assert result.match_count == 1
|
||||
assert "test.py" in result.matches
|
||||
|
||||
|
||||
@pytest.mark.skipif(not shutil.which("grep"), reason="GNU grep not available")
|
||||
class TestGnuGrepBackend:
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_pattern_in_file(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "test.py").write_text("def hello():\n print('world')\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="hello"))
|
||||
|
||||
assert result.match_count == 1
|
||||
assert "hello" in result.matches
|
||||
assert "test.py" in result.matches
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finds_multiple_matches(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "test.py").write_text("foo\nbar\nfoo\nbaz\nfoo\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="foo"))
|
||||
|
||||
assert result.match_count == 3
|
||||
assert result.matches.count("foo") == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_on_no_matches(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "test.py").write_text("def hello():\n pass\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="nonexistent"))
|
||||
|
||||
assert result.match_count == 0
|
||||
assert result.matches == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_case_insensitive_for_lowercase_pattern(
|
||||
self, grep_gnu_only, tmp_path
|
||||
):
|
||||
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="hello"))
|
||||
|
||||
assert result.match_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_case_sensitive_for_mixed_case_pattern(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="Hello"))
|
||||
|
||||
assert result.match_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_exclude_patterns(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "included.py").write_text("match\n")
|
||||
node_modules = tmp_path / "node_modules"
|
||||
node_modules.mkdir()
|
||||
(node_modules / "excluded.js").write_text("match\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="match"))
|
||||
|
||||
assert "included.py" in result.matches
|
||||
assert "excluded.js" not in result.matches
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searches_in_specific_path(self, grep_gnu_only, tmp_path):
|
||||
subdir = tmp_path / "subdir"
|
||||
subdir.mkdir()
|
||||
(subdir / "test.py").write_text("match here\n")
|
||||
(tmp_path / "other.py").write_text("match here too\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="match", path="subdir"))
|
||||
|
||||
assert result.match_count == 1
|
||||
assert "other.py" not in result.matches
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_vibeignore_file(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / ".vibeignore").write_text("custom_dir/\n*.tmp\n")
|
||||
custom_dir = tmp_path / "custom_dir"
|
||||
custom_dir.mkdir()
|
||||
(custom_dir / "excluded.py").write_text("match\n")
|
||||
(tmp_path / "excluded.tmp").write_text("match\n")
|
||||
(tmp_path / "included.py").write_text("match\n")
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="match"))
|
||||
|
||||
assert "included.py" in result.matches
|
||||
assert "excluded.py" not in result.matches
|
||||
assert "excluded.tmp" not in result.matches
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_to_max_matches(self, grep_gnu_only, tmp_path):
|
||||
(tmp_path / "test.py").write_text("\n".join(f"line {i}" for i in range(200)))
|
||||
|
||||
result = await grep_gnu_only.run(GrepArgs(pattern="line", max_matches=50))
|
||||
|
||||
assert result.match_count == 50
|
||||
assert result.was_truncated
|
||||
|
||||
|
||||
@pytest.mark.skipif(not shutil.which("rg"), reason="ripgrep not available")
|
||||
class TestRipgrepBackend:
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_case_lowercase_pattern(self, grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="hello"))
|
||||
|
||||
assert result.match_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_case_mixed_case_pattern(self, grep, tmp_path):
|
||||
(tmp_path / "test.py").write_text("Hello\nHELLO\nhello\n")
|
||||
|
||||
result = await grep.run(GrepArgs(pattern="Hello"))
|
||||
|
||||
assert result.match_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searches_ignored_files_when_use_default_ignore_false(
|
||||
self, grep, tmp_path
|
||||
):
|
||||
(tmp_path / ".ignore").write_text("ignored_by_rg/\n")
|
||||
|
||||
ignored_dir = tmp_path / "ignored_by_rg"
|
||||
ignored_dir.mkdir()
|
||||
(ignored_dir / "file.py").write_text("match\n")
|
||||
(tmp_path / "included.py").write_text("match\n")
|
||||
|
||||
result_with_ignore = await grep.run(GrepArgs(pattern="match"))
|
||||
assert "included.py" in result_with_ignore.matches
|
||||
assert "ignored_by_rg" not in result_with_ignore.matches
|
||||
|
||||
result_without_ignore = await grep.run(
|
||||
GrepArgs(pattern="match", use_default_ignore=False)
|
||||
)
|
||||
assert "included.py" in result_without_ignore.matches
|
||||
assert "ignored_by_rg/file.py" in result_without_ignore.matches
|
||||
138
tests/tools/test_ui_bash_execution.py
Normal file
138
tests/tools/test_ui_bash_execution.py
Normal file
@@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from textual.widgets import Static
|
||||
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.textual_ui.widgets.chat_input.container import ChatInputContainer
|
||||
from vibe.cli.textual_ui.widgets.messages import BashOutputMessage, ErrorMessage
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config(tmp_path: Path) -> VibeConfig:
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False), workdir=tmp_path
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_app(vibe_config: VibeConfig) -> VibeApp:
|
||||
return VibeApp(config=vibe_config)
|
||||
|
||||
|
||||
async def _wait_for_bash_output_message(
|
||||
vibe_app: VibeApp, pilot, timeout: float = 1.0
|
||||
) -> BashOutputMessage:
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
if message := next(iter(vibe_app.query(BashOutputMessage)), None):
|
||||
return message
|
||||
await pilot.pause(0.05)
|
||||
raise TimeoutError(f"BashOutputMessage did not appear within {timeout}s")
|
||||
|
||||
|
||||
def assert_no_command_error(vibe_app: VibeApp) -> None:
|
||||
errors = list(vibe_app.query(ErrorMessage))
|
||||
if not errors:
|
||||
return
|
||||
|
||||
disallowed = {
|
||||
"Command failed",
|
||||
"Command timed out",
|
||||
"No command provided after '!'",
|
||||
}
|
||||
offending = [
|
||||
getattr(err, "_error", "")
|
||||
for err in errors
|
||||
if getattr(err, "_error", "")
|
||||
and any(phrase in getattr(err, "_error", "") for phrase in disallowed)
|
||||
]
|
||||
assert not offending, f"Unexpected command errors: {offending}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_reports_no_output(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!true"
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
output_widget = message.query_one(".bash-output", Static)
|
||||
assert str(output_widget.render()) == "(no output)"
|
||||
assert_no_command_error(vibe_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_shows_success_in_case_of_zero_code(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!true"
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
icon = message.query_one(".bash-exit-success", Static)
|
||||
assert str(icon.render()) == "✓"
|
||||
assert not list(message.query(".bash-exit-failure"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_shows_failure_in_case_of_non_zero_code(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!bash -lc 'exit 7'"
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
icon = message.query_one(".bash-exit-failure", Static)
|
||||
assert str(icon.render()) == "✗"
|
||||
code = message.query_one(".bash-exit-code", Static)
|
||||
assert "7" in str(code.render())
|
||||
assert not list(message.query(".bash-exit-success"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_handles_non_utf8_output(vibe_app: VibeApp) -> None:
|
||||
"""Assert the UI accepts decoding a non-UTF8 sequence like `printf '\xf0\x9f\x98'`.
|
||||
Whereas `printf '\xf0\x9f\x98\x8b'` prints a smiley face (😋) and would work even without those changes.
|
||||
"""
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!printf '\\xff\\xfe'"
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
output_widget = message.query_one(".bash-output", Static)
|
||||
# accept both possible encodings, as some shells emit escaped bytes as literal strings
|
||||
assert str(output_widget.render()) in {"<EFBFBD><EFBFBD>", "\xff\xfe", r"\xff\xfe"}
|
||||
assert_no_command_error(vibe_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_handles_utf8_output(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!echo hello"
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
output_widget = message.query_one(".bash-output", Static)
|
||||
assert str(output_widget.render()) == "hello\n"
|
||||
assert_no_command_error(vibe_app)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_handles_non_utf8_stderr(vibe_app: VibeApp) -> None:
|
||||
async with vibe_app.run_test() as pilot:
|
||||
chat_input = vibe_app.query_one(ChatInputContainer)
|
||||
chat_input.value = "!bash -lc \"printf '\\\\xff\\\\xfe' 1>&2\""
|
||||
|
||||
await pilot.press("enter")
|
||||
message = await _wait_for_bash_output_message(vibe_app, pilot)
|
||||
output_widget = message.query_one(".bash-output", Static)
|
||||
assert str(output_widget.render()) == "<EFBFBD><EFBFBD>"
|
||||
assert_no_command_error(vibe_app)
|
||||
249
tests/update_notifier/test_github_version_update_gateway.py
Normal file
249
tests/update_notifier/test_github_version_update_gateway.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from vibe.cli.update_notifier.github_version_update_gateway import (
|
||||
GitHubVersionUpdateGateway,
|
||||
)
|
||||
from vibe.cli.update_notifier.version_update_gateway import (
|
||||
VersionUpdateGatewayCause,
|
||||
VersionUpdateGatewayError,
|
||||
)
|
||||
|
||||
Handler = Callable[[httpx.Request], httpx.Response]
|
||||
|
||||
GITHUB_API_URL = "https://api.github.com"
|
||||
|
||||
|
||||
def _raise_connect_timeout(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ConnectTimeout("boom", request=request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_latest_version_when_available() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
assert request.headers.get("Authorization") == "Bearer token"
|
||||
return httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
json=[{"tag_name": "v1.2.3", "prerelease": False, "draft": False}],
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway(
|
||||
"owner", "repo", token="token", client=client
|
||||
)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is not None
|
||||
assert update.latest_version == "1.2.3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_uppercase_prefix_from_tag_name() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
json=[{"tag_name": "V0.9.0", "prerelease": False, "draft": False}],
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is not None
|
||||
assert update.latest_version == "0.9.0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_considers_no_update_available_when_no_releases_are_found() -> None:
|
||||
"""If the repository cannot be accessed (e.g. invalid token), the response will be 404.
|
||||
But using API 'releases/latest', if no release has been created, the response will ALSO be 404.
|
||||
|
||||
This test ensures that we consider no update available when no releases are found.
|
||||
(And this is why we are using "releases" with a per_page=1 parameter, instead of "releases/latest")
|
||||
"""
|
||||
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(status_code=httpx.codes.OK, json=[])
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_considers_no_update_available_when_only_drafts_and_prereleases_are_found() -> (
|
||||
None
|
||||
):
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
json=[
|
||||
{"tag_name": "v2.0.0-beta", "prerelease": True, "draft": False},
|
||||
{"tag_name": "v2.0.0", "prerelease": False, "draft": True},
|
||||
],
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_picks_the_most_recently_published_non_prerelease_and_non_draft() -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
status_code=httpx.codes.OK,
|
||||
json=[
|
||||
{
|
||||
"tag_name": "v2.0.0-beta",
|
||||
"prerelease": True,
|
||||
"draft": False,
|
||||
"published_at": "2025-10-25T112:00:00Z",
|
||||
},
|
||||
{
|
||||
"tag_name": "v2.0.0",
|
||||
"prerelease": False,
|
||||
"draft": True,
|
||||
"published_at": "2025-10-26T112:00:00Z",
|
||||
},
|
||||
{
|
||||
"tag_name": "v1.12.455",
|
||||
"prerelease": False,
|
||||
"draft": False,
|
||||
"published_at": "2025-11-02T112:00:00Z",
|
||||
},
|
||||
{
|
||||
"tag_name": "1.12.400",
|
||||
"prerelease": False,
|
||||
"draft": False,
|
||||
"published_at": "2025-11-10T112:00:00Z",
|
||||
},
|
||||
{
|
||||
"tag_name": "1.12.300",
|
||||
"prerelease": False,
|
||||
"draft": False,
|
||||
"published_at": "2025-11-11T112:00:00Z",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is not None
|
||||
assert update.latest_version == "1.12.300"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"payload",
|
||||
[
|
||||
[{"tag_name": "v2.0.0-beta", "prerelease": True, "draft": False}],
|
||||
[{"tag_name": "v2.0.0", "prerelease": False, "draft": True}],
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_draft_releases_and_prereleases(
|
||||
payload: dict[str, object],
|
||||
) -> None:
|
||||
def handler(request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(status_code=httpx.codes.OK, json=payload)
|
||||
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
update = await notifier.fetch_update()
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("handler", "expected_cause", "expected_custom_message"),
|
||||
[
|
||||
(
|
||||
lambda _: httpx.Response(status_code=httpx.codes.NOT_FOUND),
|
||||
VersionUpdateGatewayCause.NOT_FOUND,
|
||||
"Unable to fetch the GitHub releases. Did you export a GITHUB_TOKEN environment variable?",
|
||||
),
|
||||
(
|
||||
lambda _: httpx.Response(
|
||||
status_code=httpx.codes.FORBIDDEN,
|
||||
headers={"X-RateLimit-Remaining": "0"},
|
||||
),
|
||||
VersionUpdateGatewayCause.TOO_MANY_REQUESTS,
|
||||
None,
|
||||
),
|
||||
(
|
||||
lambda _: httpx.Response(status_code=httpx.codes.TOO_MANY_REQUESTS),
|
||||
VersionUpdateGatewayCause.TOO_MANY_REQUESTS,
|
||||
None,
|
||||
),
|
||||
(
|
||||
lambda _: httpx.Response(status_code=httpx.codes.FORBIDDEN),
|
||||
VersionUpdateGatewayCause.FORBIDDEN,
|
||||
None,
|
||||
),
|
||||
(
|
||||
lambda _: httpx.Response(status_code=httpx.codes.INTERNAL_SERVER_ERROR),
|
||||
VersionUpdateGatewayCause.ERROR_RESPONSE,
|
||||
None,
|
||||
),
|
||||
(
|
||||
lambda _: httpx.Response(status_code=httpx.codes.OK, text="not json"),
|
||||
VersionUpdateGatewayCause.INVALID_RESPONSE,
|
||||
None,
|
||||
),
|
||||
(_raise_connect_timeout, VersionUpdateGatewayCause.REQUEST_FAILED, None),
|
||||
],
|
||||
ids=[
|
||||
"not_found",
|
||||
"rate_limit_header",
|
||||
"rate_limit_status",
|
||||
"forbidden",
|
||||
"error_response",
|
||||
"invalid_json",
|
||||
"request_error",
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_fetching_update_fails(
|
||||
handler: Handler,
|
||||
expected_cause: VersionUpdateGatewayCause,
|
||||
expected_custom_message: str | None,
|
||||
) -> None:
|
||||
transport = httpx.MockTransport(handler)
|
||||
async with httpx.AsyncClient(
|
||||
transport=transport, base_url=GITHUB_API_URL
|
||||
) as client:
|
||||
notifier = GitHubVersionUpdateGateway("owner", "repo", client=client)
|
||||
with pytest.raises(VersionUpdateGatewayError) as excinfo:
|
||||
await notifier.fetch_update()
|
||||
|
||||
assert excinfo.value.cause == expected_cause
|
||||
if expected_custom_message is not None:
|
||||
assert str(excinfo.value) == expected_custom_message
|
||||
161
tests/update_notifier/test_ui_version_update_notification.py
Normal file
161
tests/update_notifier/test_ui_version_update_notification.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Protocol
|
||||
|
||||
import pytest
|
||||
from textual.app import Notification
|
||||
|
||||
from vibe.cli.textual_ui.app import VibeApp
|
||||
from vibe.cli.update_notifier.fake_version_update_gateway import (
|
||||
FakeVersionUpdateGateway,
|
||||
)
|
||||
from vibe.cli.update_notifier.version_update_gateway import (
|
||||
VersionUpdate,
|
||||
VersionUpdateGatewayCause,
|
||||
VersionUpdateGatewayError,
|
||||
)
|
||||
from vibe.core.config import SessionLoggingConfig, VibeConfig
|
||||
|
||||
|
||||
async def _wait_for_notification(
|
||||
app: VibeApp, pilot, *, timeout: float = 1.0, interval: float = 0.05
|
||||
) -> Notification:
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
|
||||
while loop.time() < deadline:
|
||||
notifications = list(app._notifications)
|
||||
if notifications:
|
||||
return notifications[-1]
|
||||
await pilot.pause(interval)
|
||||
|
||||
pytest.fail("Notification not displayed")
|
||||
|
||||
|
||||
async def _assert_no_notifications(
|
||||
app: VibeApp, pilot, *, timeout: float = 1.0, interval: float = 0.05
|
||||
) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
deadline = loop.time() + timeout
|
||||
|
||||
while loop.time() < deadline:
|
||||
if app._notifications:
|
||||
pytest.fail("Notification unexpectedly displayed")
|
||||
await pilot.pause(interval)
|
||||
|
||||
assert not app._notifications
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vibe_config_with_update_checks_enabled() -> VibeConfig:
|
||||
return VibeConfig(
|
||||
session_logging=SessionLoggingConfig(enabled=False), enable_update_checks=True
|
||||
)
|
||||
|
||||
|
||||
class VibeAppFactory(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
notifier: FakeVersionUpdateGateway,
|
||||
config: VibeConfig | None = None,
|
||||
auto_approve: bool = False,
|
||||
current_version: str = "0.1.0",
|
||||
) -> VibeApp: ...
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_vibe_app(vibe_config_with_update_checks_enabled: VibeConfig) -> VibeAppFactory:
|
||||
def _make_app(
|
||||
*,
|
||||
notifier: FakeVersionUpdateGateway,
|
||||
config: VibeConfig | None = None,
|
||||
auto_approve: bool = False,
|
||||
current_version: str = "0.1.0",
|
||||
) -> VibeApp:
|
||||
return VibeApp(
|
||||
config=config or vibe_config_with_update_checks_enabled,
|
||||
auto_approve=auto_approve,
|
||||
version_update_notifier=notifier,
|
||||
current_version=current_version,
|
||||
)
|
||||
|
||||
return _make_app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_displays_update_notification(make_vibe_app: VibeAppFactory) -> None:
|
||||
notifier = FakeVersionUpdateGateway(update=VersionUpdate(latest_version="0.2.0"))
|
||||
app = make_vibe_app(notifier=notifier)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
notification = await _wait_for_notification(app, pilot, timeout=0.3)
|
||||
|
||||
assert notification.severity == "information"
|
||||
assert notification.title == "Update available"
|
||||
assert (
|
||||
notification.message
|
||||
== '0.1.0 => 0.2.0\nRun "uv tool upgrade mistral-vibe" to update'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_does_not_display_update_notification_when_not_available(
|
||||
make_vibe_app: VibeAppFactory,
|
||||
) -> None:
|
||||
notifier = FakeVersionUpdateGateway(update=None)
|
||||
app = make_vibe_app(notifier=notifier)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await _assert_no_notifications(app, pilot, timeout=0.3)
|
||||
assert notifier.fetch_update_calls == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_displays_warning_toast_when_check_fails(
|
||||
make_vibe_app: VibeAppFactory,
|
||||
) -> None:
|
||||
notifier = FakeVersionUpdateGateway(
|
||||
error=VersionUpdateGatewayError(cause=VersionUpdateGatewayCause.FORBIDDEN)
|
||||
)
|
||||
app = make_vibe_app(notifier=notifier)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await pilot.pause(0.3)
|
||||
notifications = list(app._notifications)
|
||||
|
||||
assert notifications
|
||||
warning = notifications[-1]
|
||||
assert warning.severity == "warning"
|
||||
assert "forbidden" in warning.message.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_does_not_invoke_gateway_nor_show_error_notification_when_update_checks_are_disabled(
|
||||
vibe_config_with_update_checks_enabled: VibeConfig, make_vibe_app: VibeAppFactory
|
||||
) -> None:
|
||||
config = vibe_config_with_update_checks_enabled
|
||||
config.enable_update_checks = False
|
||||
notifier = FakeVersionUpdateGateway(update=VersionUpdate(latest_version="0.2.0"))
|
||||
app = make_vibe_app(notifier=notifier, config=config)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await _assert_no_notifications(app, pilot, timeout=0.3)
|
||||
|
||||
assert notifier.fetch_update_calls == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ui_does_not_invoke_gateway_nor_show_update_notification_when_update_checks_are_disabled(
|
||||
vibe_config_with_update_checks_enabled: VibeConfig, make_vibe_app: VibeAppFactory
|
||||
) -> None:
|
||||
config = vibe_config_with_update_checks_enabled
|
||||
config.enable_update_checks = False
|
||||
notifier = FakeVersionUpdateGateway(update=VersionUpdate(latest_version="0.2.0"))
|
||||
app = make_vibe_app(notifier=notifier, config=config)
|
||||
|
||||
async with app.run_test() as pilot:
|
||||
await _assert_no_notifications(app, pilot, timeout=0.3)
|
||||
|
||||
assert notifier.fetch_update_calls == 0
|
||||
146
tests/update_notifier/test_version_update_use_case.py
Normal file
146
tests/update_notifier/test_version_update_use_case.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from vibe.cli.update_notifier.fake_version_update_gateway import (
|
||||
FakeVersionUpdateGateway,
|
||||
)
|
||||
from vibe.cli.update_notifier.version_update import (
|
||||
VersionUpdateError,
|
||||
is_version_update_available,
|
||||
)
|
||||
from vibe.cli.update_notifier.version_update_gateway import (
|
||||
VersionUpdate,
|
||||
VersionUpdateGatewayCause,
|
||||
VersionUpdateGatewayError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_the_latest_version_update_when_available() -> None:
|
||||
latest_update = "1.0.3"
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version=latest_update)
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version="1.0.0"
|
||||
)
|
||||
|
||||
assert update is not None
|
||||
assert update.latest_version == latest_update
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_the_current_version_is_the_latest() -> None:
|
||||
current_version = "1.0.0"
|
||||
latest_version = "1.0.0"
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version=latest_version)
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version=current_version
|
||||
)
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_the_current_version_is_greater_than_the_latest() -> (
|
||||
None
|
||||
):
|
||||
current_version = "0.2.0"
|
||||
latest_version = "0.1.2"
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version=latest_version)
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version=current_version
|
||||
)
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_no_version_is_available() -> None:
|
||||
version_update_notifier = FakeVersionUpdateGateway(update=None)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version="1.0.0"
|
||||
)
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_latest_version_is_invalid() -> None:
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version="invalid-version")
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version="1.0.0"
|
||||
)
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replaces_hyphens_with_plus_signs_in_latest_version_to_conform_with_PEP_440() -> (
|
||||
None
|
||||
):
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
# if we were not replacing hyphens with plus signs, this should fail for PEP 440
|
||||
update=VersionUpdate(latest_version="1.6.1-jetbrains")
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version="1.0.0"
|
||||
)
|
||||
|
||||
assert update is not None
|
||||
assert update.latest_version == "1.6.1-jetbrains"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieves_nothing_when_current_version_is_invalid() -> None:
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
update=VersionUpdate(latest_version="1.0.1")
|
||||
)
|
||||
|
||||
update = await is_version_update_available(
|
||||
version_update_notifier, current_version="invalid-version"
|
||||
)
|
||||
|
||||
assert update is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cause", "expected_message_substring"),
|
||||
[
|
||||
(VersionUpdateGatewayCause.TOO_MANY_REQUESTS, "Rate limit exceeded"),
|
||||
(VersionUpdateGatewayCause.INVALID_RESPONSE, "invalid response"),
|
||||
(
|
||||
VersionUpdateGatewayCause.NOT_FOUND,
|
||||
"Unable to fetch the releases. Please check your permissions.",
|
||||
),
|
||||
(VersionUpdateGatewayCause.ERROR_RESPONSE, "Unexpected response"),
|
||||
(VersionUpdateGatewayCause.REQUEST_FAILED, "Network error"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_version_update_error(
|
||||
cause: VersionUpdateGatewayCause, expected_message_substring: str
|
||||
) -> None:
|
||||
version_update_notifier = FakeVersionUpdateGateway(
|
||||
error=VersionUpdateGatewayError(cause=cause)
|
||||
)
|
||||
|
||||
with pytest.raises(VersionUpdateError) as excinfo:
|
||||
await is_version_update_available(
|
||||
version_update_notifier, current_version="1.0.0"
|
||||
)
|
||||
|
||||
assert expected_message_substring in str(excinfo.value)
|
||||
45
vibe-acp.spec
Normal file
45
vibe-acp.spec
Normal file
@@ -0,0 +1,45 @@
|
||||
# -*- mode: python ; coding: utf-8 -*-
|
||||
|
||||
|
||||
a = Analysis(
|
||||
['vibe/acp/entrypoint.py'],
|
||||
pathex=[],
|
||||
binaries=[],
|
||||
datas=[
|
||||
# By default, pyinstaller doesn't include the .md files
|
||||
('vibe/core/prompts/*.md', 'vibe/core/prompts'),
|
||||
('vibe/core/tools/builtins/prompts/*.md', 'vibe/core/tools/builtins/prompts'),
|
||||
# This is necessary because tools are dynamically called in vibe, meaning there is no static reference to those files
|
||||
('vibe/core/tools/builtins/*.py', 'vibe/core/tools/builtins'),
|
||||
('vibe/acp/tools/builtins/*.py', 'vibe/acp/tools/builtins'),
|
||||
],
|
||||
hiddenimports=[],
|
||||
hookspath=[],
|
||||
hooksconfig={},
|
||||
runtime_hooks=[],
|
||||
excludes=[],
|
||||
noarchive=False,
|
||||
optimize=0,
|
||||
)
|
||||
pyz = PYZ(a.pure)
|
||||
|
||||
exe = EXE(
|
||||
pyz,
|
||||
a.scripts,
|
||||
a.binaries,
|
||||
a.datas,
|
||||
[],
|
||||
name='vibe-acp',
|
||||
debug=False,
|
||||
bootloader_ignore_signals=False,
|
||||
strip=False,
|
||||
upx=True,
|
||||
upx_exclude=[],
|
||||
runtime_tmpdir=None,
|
||||
console=True,
|
||||
disable_windowed_traceback=False,
|
||||
argv_emulation=False,
|
||||
target_arch=None,
|
||||
codesign_identity=None,
|
||||
entitlements_file=None,
|
||||
)
|
||||
5
vibe/__init__.py
Normal file
5
vibe/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
VIBE_ROOT = Path(__file__).parent
|
||||
0
vibe/acp/__init__.py
Normal file
0
vibe/acp/__init__.py
Normal file
441
vibe/acp/acp_agent.py
Normal file
441
vibe/acp/acp_agent.py
Normal file
@@ -0,0 +1,441 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Any, cast, override
|
||||
|
||||
from acp import (
|
||||
PROTOCOL_VERSION,
|
||||
Agent as AcpAgent,
|
||||
AgentSideConnection,
|
||||
AuthenticateRequest,
|
||||
CancelNotification,
|
||||
InitializeRequest,
|
||||
InitializeResponse,
|
||||
LoadSessionRequest,
|
||||
NewSessionRequest,
|
||||
NewSessionResponse,
|
||||
PromptRequest,
|
||||
PromptResponse,
|
||||
RequestError,
|
||||
RequestPermissionRequest,
|
||||
SessionNotification,
|
||||
SetSessionModelRequest,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeRequest,
|
||||
SetSessionModeResponse,
|
||||
stdio_streams,
|
||||
)
|
||||
from acp.helpers import ContentBlock, SessionUpdate
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AgentMessageChunk,
|
||||
AllowedOutcome,
|
||||
AuthenticateResponse,
|
||||
AuthMethod,
|
||||
Implementation,
|
||||
ModelInfo,
|
||||
PromptCapabilities,
|
||||
SessionModelState,
|
||||
SessionModeState,
|
||||
TextContentBlock,
|
||||
TextResourceContents,
|
||||
ToolCall,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import BaseAcpTool
|
||||
from vibe.acp.tools.session_update import (
|
||||
tool_call_session_update,
|
||||
tool_result_session_update,
|
||||
)
|
||||
from vibe.acp.utils import TOOL_OPTIONS, ToolOption, VibeSessionMode
|
||||
from vibe.core import __version__
|
||||
from vibe.core.agent import Agent as VibeAgent
|
||||
from vibe.core.autocompletion.path_prompt_adapter import render_path_prompt
|
||||
from vibe.core.config import MissingAPIKeyError, VibeConfig, load_api_keys_from_env
|
||||
from vibe.core.types import (
|
||||
AssistantEvent,
|
||||
AsyncApprovalCallback,
|
||||
ToolCallEvent,
|
||||
ToolResultEvent,
|
||||
)
|
||||
from vibe.core.utils import CancellationReason, get_user_cancellation_message
|
||||
|
||||
|
||||
class AcpSession(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
id: str
|
||||
agent: VibeAgent
|
||||
mode_id: VibeSessionMode = VibeSessionMode.APPROVAL_REQUIRED
|
||||
task: asyncio.Task[None] | None = None
|
||||
|
||||
|
||||
class VibeAcpAgent(AcpAgent):
|
||||
def __init__(self, connection: AgentSideConnection) -> None:
|
||||
self.sessions: dict[str, AcpSession] = {}
|
||||
self.connection = connection
|
||||
self.client_capabilities = None
|
||||
|
||||
@override
|
||||
async def initialize(self, params: InitializeRequest) -> InitializeResponse:
|
||||
self.client_capabilities = params.clientCapabilities
|
||||
|
||||
# The ACP Agent process can be launched in 3 different ways, depending on installation
|
||||
# - dev mode: `uv run vibe-acp`, ran from the project root
|
||||
# - uv tool install: `vibe-acp`, similar to dev mode, but uv takes care of path resolution
|
||||
# - bundled binary: `./vibe-acp` from binary location
|
||||
# The 2 first modes are working similarly, under the hood uv runs `/some/python /my/entrypoint.py``
|
||||
# The last mode is quite different as our bundler also includes the python install.
|
||||
# So sys.executable is already /path/to/binary/vibe-acp.
|
||||
# For this reason, we make a distinction in the way we call the setup command
|
||||
command = sys.executable
|
||||
if "python" not in Path(command).name:
|
||||
# It's the case for bundled binaries, we don't need any other arguments
|
||||
args = ["--setup"]
|
||||
else:
|
||||
script_name = sys.argv[0]
|
||||
args = [script_name, "--setup"]
|
||||
|
||||
auth_methods = [
|
||||
AuthMethod(
|
||||
id="vibe-setup",
|
||||
name="Register your API Key",
|
||||
description="Register your API Key inside Mistral Vibe",
|
||||
field_meta={
|
||||
"terminal-auth": {
|
||||
"command": command,
|
||||
"args": args,
|
||||
"label": "Mistral Vibe Setup",
|
||||
}
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
response = InitializeResponse(
|
||||
agentCapabilities=AgentCapabilities(
|
||||
loadSession=False,
|
||||
promptCapabilities=PromptCapabilities(
|
||||
audio=False, embeddedContext=True, image=False
|
||||
),
|
||||
),
|
||||
protocolVersion=PROTOCOL_VERSION,
|
||||
agentInfo=Implementation(
|
||||
name="@mistralai/mistral-vibe",
|
||||
title="Mistral Vibe",
|
||||
version=__version__,
|
||||
),
|
||||
authMethods=auth_methods,
|
||||
)
|
||||
return response
|
||||
|
||||
@override
|
||||
async def authenticate(
|
||||
self, params: AuthenticateRequest
|
||||
) -> AuthenticateResponse | None:
|
||||
raise NotImplementedError("Not implemented yet")
|
||||
|
||||
@override
|
||||
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
|
||||
capability_disabled_tools = self._get_disabled_tools_from_capabilities()
|
||||
load_api_keys_from_env()
|
||||
try:
|
||||
config = VibeConfig.load(
|
||||
workdir=Path(params.cwd),
|
||||
tool_paths=[str(VIBE_ROOT / "acp" / "tools" / "builtins")],
|
||||
disabled_tools=capability_disabled_tools,
|
||||
)
|
||||
except MissingAPIKeyError as e:
|
||||
raise RequestError.auth_required({
|
||||
"message": "You must be authenticated before creating a new session"
|
||||
}) from e
|
||||
|
||||
agent = VibeAgent(config=config, auto_approve=False, enable_streaming=True)
|
||||
# NOTE: For now, we pin session.id to agent.session_id right after init time.
|
||||
# We should just use agent.session_id everywhere, but it can still change during
|
||||
# session lifetime (e.g. agent.compact is called).
|
||||
# We should refactor agent.session_id to make it immutable in ACP context.
|
||||
session = AcpSession(id=agent.session_id, agent=agent)
|
||||
self.sessions[session.id] = session
|
||||
|
||||
if not agent.auto_approve:
|
||||
agent.set_approval_callback(
|
||||
self._create_approval_callback(agent.session_id)
|
||||
)
|
||||
|
||||
response = NewSessionResponse(
|
||||
sessionId=agent.session_id,
|
||||
models=SessionModelState(
|
||||
currentModelId=agent.config.active_model,
|
||||
availableModels=[
|
||||
ModelInfo(modelId=model.alias, name=model.alias)
|
||||
for model in agent.config.models
|
||||
],
|
||||
),
|
||||
modes=SessionModeState(
|
||||
currentModeId=session.mode_id,
|
||||
availableModes=VibeSessionMode.get_all_acp_session_modes(),
|
||||
),
|
||||
)
|
||||
return response
|
||||
|
||||
def _get_disabled_tools_from_capabilities(self) -> list[str]:
|
||||
if not self.client_capabilities:
|
||||
return []
|
||||
|
||||
disabled: list[str] = []
|
||||
|
||||
if not self.client_capabilities.terminal:
|
||||
disabled.append("bash")
|
||||
|
||||
if fs := self.client_capabilities.fs:
|
||||
if not fs.readTextFile:
|
||||
disabled.append("read_file")
|
||||
if not fs.writeTextFile:
|
||||
disabled.append("write_file")
|
||||
disabled.append("search_replace")
|
||||
|
||||
return disabled
|
||||
|
||||
def _create_approval_callback(self, session_id: str) -> AsyncApprovalCallback:
|
||||
|
||||
async def approval_callback(
|
||||
tool_name: str, args: dict[str, Any], tool_call_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
# Create the tool call update
|
||||
tool_call = ToolCall(toolCallId=tool_call_id)
|
||||
|
||||
# Request permission from the user
|
||||
request = RequestPermissionRequest(
|
||||
sessionId=session_id, toolCall=tool_call, options=TOOL_OPTIONS
|
||||
)
|
||||
|
||||
response = await self.connection.requestPermission(request)
|
||||
|
||||
# Parse the response using isinstance for proper type narrowing
|
||||
if response.outcome.outcome == "selected":
|
||||
outcome = cast(AllowedOutcome, response.outcome)
|
||||
return self._handle_permission_selection(outcome.optionId)
|
||||
else:
|
||||
return (
|
||||
"n",
|
||||
str(
|
||||
get_user_cancellation_message(
|
||||
CancellationReason.OPERATION_CANCELLED
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return approval_callback
|
||||
|
||||
@staticmethod
|
||||
def _handle_permission_selection(option_id: str) -> tuple[str, str | None]:
|
||||
match option_id:
|
||||
case ToolOption.ALLOW_ONCE:
|
||||
return ("y", None)
|
||||
case ToolOption.ALLOW_ALWAYS:
|
||||
return ("a", None)
|
||||
case ToolOption.REJECT_ONCE:
|
||||
return ("n", "User rejected the tool call, provide an alternative plan")
|
||||
case _:
|
||||
return ("n", f"Unknown option: {option_id}")
|
||||
|
||||
def _get_session(self, session_id: str) -> AcpSession:
|
||||
if session_id not in self.sessions:
|
||||
raise RequestError.invalid_params({"session": "Not found"})
|
||||
return self.sessions[session_id]
|
||||
|
||||
@override
|
||||
async def loadSession(self, params: LoadSessionRequest) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@override
|
||||
async def setSessionMode(
|
||||
self, params: SetSessionModeRequest
|
||||
) -> SetSessionModeResponse | None:
|
||||
session = self._get_session(params.sessionId)
|
||||
|
||||
if not VibeSessionMode.is_valid(params.modeId):
|
||||
return None
|
||||
|
||||
session.mode_id = VibeSessionMode(params.modeId)
|
||||
session.agent.auto_approve = params.modeId == VibeSessionMode.AUTO_APPROVE
|
||||
|
||||
return SetSessionModeResponse()
|
||||
|
||||
@override
|
||||
async def setSessionModel(
|
||||
self, params: SetSessionModelRequest
|
||||
) -> SetSessionModelResponse | None:
|
||||
session = self._get_session(params.sessionId)
|
||||
|
||||
model_aliases = [model.alias for model in session.agent.config.models]
|
||||
if params.modelId not in model_aliases:
|
||||
return None
|
||||
|
||||
VibeConfig.save_updates({"active_model": params.modelId})
|
||||
|
||||
new_config = VibeConfig.load(
|
||||
workdir=session.agent.config.workdir,
|
||||
tool_paths=session.agent.config.tool_paths,
|
||||
disabled_tools=self._get_disabled_tools_from_capabilities(),
|
||||
)
|
||||
|
||||
await session.agent.reload_with_initial_messages(config=new_config)
|
||||
|
||||
return SetSessionModelResponse()
|
||||
|
||||
@override
|
||||
async def prompt(self, params: PromptRequest) -> PromptResponse:
|
||||
session = self._get_session(params.sessionId)
|
||||
|
||||
if session.task is not None:
|
||||
raise RuntimeError(
|
||||
"Concurrent prompts are not supported yet, wait for agent to finish"
|
||||
)
|
||||
|
||||
text_prompt = self._build_text_prompt(params.prompt)
|
||||
|
||||
async def agent_task() -> None:
|
||||
async for update in self._run_agent_loop(session, text_prompt):
|
||||
await self.connection.sessionUpdate(
|
||||
SessionNotification(sessionId=session.id, update=update)
|
||||
)
|
||||
|
||||
try:
|
||||
session.task = asyncio.create_task(agent_task())
|
||||
await session.task
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return PromptResponse(stopReason="cancelled")
|
||||
|
||||
except Exception as e:
|
||||
await self.connection.sessionUpdate(
|
||||
SessionNotification(
|
||||
sessionId=params.sessionId,
|
||||
update=AgentMessageChunk(
|
||||
sessionUpdate="agent_message_chunk",
|
||||
content=TextContentBlock(type="text", text=f"Error: {e!s}"),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return PromptResponse(stopReason="refusal")
|
||||
|
||||
finally:
|
||||
session.task = None
|
||||
|
||||
return PromptResponse(stopReason="end_turn")
|
||||
|
||||
def _build_text_prompt(self, acp_prompt: list[ContentBlock]) -> str:
|
||||
text_prompt = ""
|
||||
for block in acp_prompt:
|
||||
separator = "\n\n" if text_prompt else ""
|
||||
match block.type:
|
||||
# NOTE: ACP supports annotations, but we don't use them here yet.
|
||||
case "text":
|
||||
text_prompt = f"{text_prompt}{separator}{block.text}"
|
||||
case "resource":
|
||||
block_content = (
|
||||
block.resource.text
|
||||
if isinstance(block.resource, TextResourceContents)
|
||||
else block.resource.blob
|
||||
)
|
||||
fields = {"path": block.resource.uri, "content": block_content}
|
||||
parts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in fields.items()
|
||||
if v is not None and (v or isinstance(v, (int, float)))
|
||||
]
|
||||
block_prompt = "\n".join(parts)
|
||||
text_prompt = f"{text_prompt}{separator}{block_prompt}"
|
||||
case "resource_link":
|
||||
# NOTE: we currently keep more information than just the URI
|
||||
# making it more detailed than the output of the read_file tool.
|
||||
# This is OK, but might be worth testing how it affect performance.
|
||||
fields = {
|
||||
"uri": block.uri,
|
||||
"name": block.name,
|
||||
"title": block.title,
|
||||
"description": block.description,
|
||||
"mimeType": block.mimeType,
|
||||
"size": block.size,
|
||||
}
|
||||
parts = [
|
||||
f"{k}: {v}"
|
||||
for k, v in fields.items()
|
||||
if v is not None and (v or isinstance(v, (int, float)))
|
||||
]
|
||||
block_prompt = "\n".join(parts)
|
||||
text_prompt = f"{text_prompt}{separator}{block_prompt}"
|
||||
case _:
|
||||
raise ValueError(f"Unsupported content block type: {block.type}")
|
||||
return text_prompt
|
||||
|
||||
async def _run_agent_loop(
|
||||
self, session: AcpSession, prompt: str
|
||||
) -> AsyncGenerator[SessionUpdate]:
|
||||
rendered_prompt = render_path_prompt(
|
||||
prompt, base_dir=session.agent.config.effective_workdir
|
||||
)
|
||||
async for event in session.agent.act(rendered_prompt):
|
||||
if isinstance(event, AssistantEvent):
|
||||
yield AgentMessageChunk(
|
||||
sessionUpdate="agent_message_chunk",
|
||||
content=TextContentBlock(type="text", text=event.content),
|
||||
)
|
||||
|
||||
elif isinstance(event, ToolCallEvent):
|
||||
if issubclass(event.tool_class, BaseAcpTool):
|
||||
event.tool_class.update_tool_state(
|
||||
tool_manager=session.agent.tool_manager,
|
||||
connection=self.connection,
|
||||
session_id=session.id,
|
||||
tool_call_id=event.tool_call_id,
|
||||
)
|
||||
|
||||
session_update = tool_call_session_update(event)
|
||||
if session_update:
|
||||
yield session_update
|
||||
|
||||
elif isinstance(event, ToolResultEvent):
|
||||
session_update = tool_result_session_update(event)
|
||||
if session_update:
|
||||
yield session_update
|
||||
|
||||
@override
|
||||
async def cancel(self, params: CancelNotification) -> None:
|
||||
session = self._get_session(params.sessionId)
|
||||
if session.task and not session.task.done():
|
||||
session.task.cancel()
|
||||
session.task = None
|
||||
|
||||
@override
|
||||
async def extMethod(self, method: str, params: dict) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
@override
|
||||
async def extNotification(self, method: str, params: dict) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def _run_acp_server() -> None:
|
||||
reader, writer = await stdio_streams()
|
||||
|
||||
AgentSideConnection(lambda connection: VibeAcpAgent(connection), writer, reader)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
def run_acp_server() -> None:
|
||||
try:
|
||||
asyncio.run(_run_acp_server())
|
||||
except KeyboardInterrupt:
|
||||
# This is expected when the server is terminated
|
||||
pass
|
||||
except Exception as e:
|
||||
# Log any unexpected errors
|
||||
print(f"ACP Agent Server error: {e}", file=sys.stderr)
|
||||
raise
|
||||
37
vibe/acp/entrypoint.py
Normal file
37
vibe/acp/entrypoint.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
|
||||
from vibe.acp.acp_agent import run_acp_server
|
||||
from vibe.setup.onboarding import run_onboarding
|
||||
|
||||
# Configure line buffering for subprocess communication
|
||||
sys.stdout.reconfigure(line_buffering=True) # pyright: ignore[reportAttributeAccessIssue]
|
||||
sys.stderr.reconfigure(line_buffering=True) # pyright: ignore[reportAttributeAccessIssue]
|
||||
sys.stdin.reconfigure(line_buffering=True) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Arguments:
|
||||
setup: bool
|
||||
|
||||
|
||||
def parse_arguments() -> Arguments:
|
||||
parser = argparse.ArgumentParser(description="Run Mistral Vibe in ACP mode")
|
||||
parser.add_argument("--setup", action="store_true", help="Setup API key and exit")
|
||||
args = parser.parse_args()
|
||||
return Arguments(setup=args.setup)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_arguments()
|
||||
if args.setup:
|
||||
run_onboarding()
|
||||
sys.exit(0)
|
||||
run_acp_server()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
vibe/acp/tools/__init__.py
Normal file
0
vibe/acp/tools/__init__.py
Normal file
100
vibe/acp/tools/base.py
Normal file
100
vibe/acp/tools/base.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Protocol, cast, runtime_checkable
|
||||
|
||||
from acp import AgentSideConnection, SessionNotification
|
||||
from acp.helpers import SessionUpdate, ToolCallContentVariant
|
||||
from acp.schema import ToolCallProgress
|
||||
from pydantic import Field
|
||||
|
||||
from vibe.core.tools.base import BaseTool, ToolError
|
||||
from vibe.core.tools.manager import ToolManager
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
from vibe.core.utils import logger
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolCallSessionUpdateProtocol(Protocol):
|
||||
@classmethod
|
||||
def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolResultSessionUpdateProtocol(Protocol):
|
||||
@classmethod
|
||||
def tool_result_session_update(
|
||||
cls, event: ToolResultEvent
|
||||
) -> SessionUpdate | None: ...
|
||||
|
||||
|
||||
class AcpToolState:
|
||||
connection: AgentSideConnection | None = Field(
|
||||
default=None, description="ACP agent-side connection"
|
||||
)
|
||||
session_id: str | None = Field(default=None, description="Current ACP session ID")
|
||||
tool_call_id: str | None = Field(
|
||||
default=None, description="Current ACP tool call ID"
|
||||
)
|
||||
|
||||
|
||||
class BaseAcpTool[ToolState: AcpToolState](BaseTool):
|
||||
state: ToolState
|
||||
|
||||
@classmethod
|
||||
def get_tool_instance(
|
||||
cls, tool_name: str, tool_manager: ToolManager
|
||||
) -> BaseAcpTool[AcpToolState]:
|
||||
return cast(BaseAcpTool[AcpToolState], tool_manager.get(tool_name))
|
||||
|
||||
@classmethod
|
||||
def update_tool_state(
|
||||
cls,
|
||||
*,
|
||||
tool_manager: ToolManager,
|
||||
connection: AgentSideConnection | None,
|
||||
session_id: str | None,
|
||||
tool_call_id: str | None,
|
||||
) -> None:
|
||||
tool_instance = cls.get_tool_instance(cls.get_name(), tool_manager)
|
||||
tool_instance.state.connection = connection
|
||||
tool_instance.state.session_id = session_id
|
||||
tool_instance.state.tool_call_id = tool_call_id
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _get_tool_state_class(cls) -> type[ToolState]: ...
|
||||
|
||||
def _load_state(self) -> tuple[AgentSideConnection, str, str | None]:
|
||||
if self.state.connection is None:
|
||||
raise ToolError(
|
||||
"Connection not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
if self.state.session_id is None:
|
||||
raise ToolError(
|
||||
"Session ID not available in tool state. This tool can only be used within an ACP session."
|
||||
)
|
||||
|
||||
return self.state.connection, self.state.session_id, self.state.tool_call_id
|
||||
|
||||
async def _send_in_progress_session_update(
|
||||
self, content: list[ToolCallContentVariant] | None = None
|
||||
) -> None:
|
||||
connection, session_id, tool_call_id = self._load_state()
|
||||
if tool_call_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await connection.sessionUpdate(
|
||||
SessionNotification(
|
||||
sessionId=session_id,
|
||||
update=ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=tool_call_id,
|
||||
status="in_progress",
|
||||
content=content,
|
||||
),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update session: {e!r}")
|
||||
144
vibe/acp/tools/builtins/bash.py
Normal file
144
vibe/acp/tools/builtins/bash.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import shlex
|
||||
|
||||
from acp import CreateTerminalRequest, TerminalHandle
|
||||
from acp.schema import (
|
||||
EnvVariable,
|
||||
TerminalToolCallContent,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
WaitForTerminalExitResponse,
|
||||
)
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
|
||||
from vibe.core.tools.base import BaseToolState, ToolError
|
||||
from vibe.core.tools.builtins.bash import Bash as CoreBashTool, BashArgs, BashResult
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
from vibe.core.utils import logger
|
||||
|
||||
|
||||
class AcpBashState(BaseToolState, AcpToolState):
|
||||
pass
|
||||
|
||||
|
||||
class Bash(CoreBashTool, BaseAcpTool[AcpBashState]):
|
||||
prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "bash.md"
|
||||
state: AcpBashState
|
||||
|
||||
@classmethod
|
||||
def _get_tool_state_class(cls) -> type[AcpBashState]:
|
||||
return AcpBashState
|
||||
|
||||
async def run(self, args: BashArgs) -> BashResult:
|
||||
connection, session_id, _ = self._load_state()
|
||||
|
||||
timeout = args.timeout or self.config.default_timeout
|
||||
max_bytes = self.config.max_output_bytes
|
||||
env, command, cmd_args = self._parse_command(args.command)
|
||||
|
||||
create_request = CreateTerminalRequest(
|
||||
sessionId=session_id,
|
||||
command=command,
|
||||
args=cmd_args,
|
||||
env=env,
|
||||
cwd=str(self.config.effective_workdir),
|
||||
outputByteLimit=max_bytes,
|
||||
)
|
||||
|
||||
try:
|
||||
terminal_handle = await connection.createTerminal(create_request)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to create terminal: {e!r}") from e
|
||||
|
||||
await self._send_in_progress_session_update([
|
||||
TerminalToolCallContent(type="terminal", terminalId=terminal_handle.id)
|
||||
])
|
||||
|
||||
try:
|
||||
exit_response = await self._wait_for_terminal_exit(
|
||||
terminal_handle, timeout, args.command
|
||||
)
|
||||
|
||||
output_response = await terminal_handle.current_output()
|
||||
|
||||
return self._build_result(
|
||||
command=args.command,
|
||||
stdout=output_response.output,
|
||||
stderr="",
|
||||
returncode=exit_response.exitCode or 0,
|
||||
)
|
||||
|
||||
finally:
|
||||
try:
|
||||
await terminal_handle.release()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release terminal: {e!r}")
|
||||
|
||||
def _parse_command(
|
||||
self, command_str: str
|
||||
) -> tuple[list[EnvVariable], str, list[str]]:
|
||||
parts = shlex.split(command_str)
|
||||
env: list[EnvVariable] = []
|
||||
command: str = ""
|
||||
args: list[str] = []
|
||||
|
||||
for part in parts:
|
||||
if "=" in part and not command:
|
||||
key, value = part.split("=", 1)
|
||||
env.append(EnvVariable(name=key, value=value))
|
||||
elif not command:
|
||||
command = part
|
||||
else:
|
||||
args.append(part)
|
||||
|
||||
return env, command, args
|
||||
|
||||
@classmethod
|
||||
def get_summary(cls, args: BashArgs) -> str:
|
||||
summary = f"{args.command}"
|
||||
if args.timeout:
|
||||
summary += f" (timeout {args.timeout}s)"
|
||||
|
||||
return summary
|
||||
|
||||
async def _wait_for_terminal_exit(
|
||||
self, terminal_handle: TerminalHandle, timeout: int, command: str
|
||||
) -> WaitForTerminalExitResponse:
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
terminal_handle.wait_for_exit(), timeout=timeout
|
||||
)
|
||||
except TimeoutError:
|
||||
try:
|
||||
await terminal_handle.kill()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to kill terminal: {e!r}")
|
||||
|
||||
raise self._build_timeout_error(command, timeout)
|
||||
|
||||
@classmethod
|
||||
def tool_call_session_update(cls, event: ToolCallEvent) -> ToolCallStart:
|
||||
if not isinstance(event.args, BashArgs):
|
||||
raise ValueError(f"Unexpected tool args: {event.args}")
|
||||
|
||||
return ToolCallStart(
|
||||
sessionUpdate="tool_call",
|
||||
title=Bash.get_summary(event.args),
|
||||
content=None,
|
||||
toolCallId=event.tool_call_id,
|
||||
kind="execute",
|
||||
rawInput=event.args.model_dump_json(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tool_result_session_update(
|
||||
cls, event: ToolResultEvent
|
||||
) -> ToolCallProgress | None:
|
||||
return ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=event.tool_call_id,
|
||||
status="failed" if event.error else "completed",
|
||||
)
|
||||
58
vibe/acp/tools/builtins/read_file.py
Normal file
58
vibe/acp/tools/builtins/read_file.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import ReadTextFileRequest
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.read_file import (
|
||||
ReadFile as CoreReadFileTool,
|
||||
ReadFileArgs,
|
||||
ReadFileResult,
|
||||
ReadFileState,
|
||||
_ReadResult,
|
||||
)
|
||||
|
||||
ReadFileResult = ReadFileResult
|
||||
|
||||
|
||||
class AcpReadFileState(ReadFileState, AcpToolState):
|
||||
pass
|
||||
|
||||
|
||||
class ReadFile(CoreReadFileTool, BaseAcpTool[AcpReadFileState]):
|
||||
state: AcpReadFileState
|
||||
prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "read_file.md"
|
||||
|
||||
@classmethod
|
||||
def _get_tool_state_class(cls) -> type[AcpReadFileState]:
|
||||
return AcpReadFileState
|
||||
|
||||
async def _read_file(self, args: ReadFileArgs, file_path: Path) -> _ReadResult:
|
||||
connection, session_id, _ = self._load_state()
|
||||
|
||||
line = args.offset + 1 if args.offset > 0 else None
|
||||
limit = args.limit
|
||||
|
||||
read_request = ReadTextFileRequest(
|
||||
sessionId=session_id, path=str(file_path), line=line, limit=limit
|
||||
)
|
||||
|
||||
await self._send_in_progress_session_update()
|
||||
|
||||
try:
|
||||
response = await connection.readTextFile(read_request)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error reading {file_path}: {e}") from e
|
||||
|
||||
content_lines = response.content.splitlines(keepends=True)
|
||||
lines_read = len(content_lines)
|
||||
bytes_read = sum(len(line.encode("utf-8")) for line in content_lines)
|
||||
|
||||
was_truncated = args.limit is not None and lines_read >= args.limit
|
||||
|
||||
return _ReadResult(
|
||||
lines=content_lines, bytes_read=bytes_read, was_truncated=was_truncated
|
||||
)
|
||||
132
vibe/acp/tools/builtins/search_replace.py
Normal file
132
vibe/acp/tools/builtins/search_replace.py
Normal file
@@ -0,0 +1,132 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import ReadTextFileRequest, WriteTextFileRequest
|
||||
from acp.helpers import SessionUpdate
|
||||
from acp.schema import (
|
||||
FileEditToolCallContent,
|
||||
ToolCallLocation,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
)
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.search_replace import (
|
||||
SearchReplace as CoreSearchReplaceTool,
|
||||
SearchReplaceArgs,
|
||||
SearchReplaceResult,
|
||||
SearchReplaceState,
|
||||
)
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
|
||||
|
||||
class AcpSearchReplaceState(SearchReplaceState, AcpToolState):
|
||||
file_backup_content: str | None = None
|
||||
|
||||
|
||||
class SearchReplace(CoreSearchReplaceTool, BaseAcpTool[AcpSearchReplaceState]):
|
||||
state: AcpSearchReplaceState
|
||||
prompt_path = (
|
||||
VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "search_replace.md"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_tool_state_class(cls) -> type[AcpSearchReplaceState]:
|
||||
return AcpSearchReplaceState
|
||||
|
||||
async def _read_file(self, file_path: Path) -> str:
|
||||
connection, session_id, _ = self._load_state()
|
||||
|
||||
read_request = ReadTextFileRequest(sessionId=session_id, path=str(file_path))
|
||||
|
||||
await self._send_in_progress_session_update()
|
||||
|
||||
try:
|
||||
response = await connection.readTextFile(read_request)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Unexpected error reading {file_path}: {e}") from e
|
||||
|
||||
self.state.file_backup_content = response.content
|
||||
return response.content
|
||||
|
||||
async def _backup_file(self, file_path: Path) -> None:
|
||||
if self.state.file_backup_content is None:
|
||||
return
|
||||
|
||||
await self._write_file(
|
||||
file_path.with_suffix(file_path.suffix + ".bak"),
|
||||
self.state.file_backup_content,
|
||||
)
|
||||
|
||||
async def _write_file(self, file_path: Path, content: str) -> None:
|
||||
connection, session_id, _ = self._load_state()
|
||||
|
||||
write_request = WriteTextFileRequest(
|
||||
sessionId=session_id, path=str(file_path), content=content
|
||||
)
|
||||
|
||||
try:
|
||||
await connection.writeTextFile(write_request)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error writing {file_path}: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
|
||||
args = event.args
|
||||
if not isinstance(args, SearchReplaceArgs):
|
||||
return None
|
||||
|
||||
blocks = cls._parse_search_replace_blocks(args.content)
|
||||
|
||||
return ToolCallStart(
|
||||
sessionUpdate="tool_call",
|
||||
title=cls.get_call_display(event).summary,
|
||||
toolCallId=event.tool_call_id,
|
||||
kind="edit",
|
||||
content=[
|
||||
FileEditToolCallContent(
|
||||
type="diff",
|
||||
path=args.file_path,
|
||||
oldText=block.search,
|
||||
newText=block.replace,
|
||||
)
|
||||
for block in blocks
|
||||
],
|
||||
locations=[ToolCallLocation(path=args.file_path)],
|
||||
rawInput=args.model_dump_json(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
|
||||
if event.error:
|
||||
return ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=event.tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
result = event.result
|
||||
if not isinstance(result, SearchReplaceResult):
|
||||
return None
|
||||
|
||||
blocks = cls._parse_search_replace_blocks(result.content)
|
||||
|
||||
return ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=event.tool_call_id,
|
||||
status="completed",
|
||||
content=[
|
||||
FileEditToolCallContent(
|
||||
type="diff",
|
||||
path=result.file,
|
||||
oldText=block.search,
|
||||
newText=block.replace,
|
||||
)
|
||||
for block in blocks
|
||||
],
|
||||
locations=[ToolCallLocation(path=result.file)],
|
||||
rawOutput=result.model_dump_json(),
|
||||
)
|
||||
65
vibe/acp/tools/builtins/todo.py
Normal file
65
vibe/acp/tools/builtins/todo.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
from acp.helpers import SessionUpdate
|
||||
from acp.schema import AgentPlanUpdate, PlanEntry, PlanEntryPriority, PlanEntryStatus
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
|
||||
from vibe.core.tools.builtins.todo import (
|
||||
Todo as CoreTodoTool,
|
||||
TodoArgs,
|
||||
TodoPriority,
|
||||
TodoResult,
|
||||
TodoState,
|
||||
TodoStatus,
|
||||
)
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
|
||||
TodoArgs = TodoArgs
|
||||
|
||||
|
||||
class AcpTodoState(TodoState, AcpToolState):
|
||||
pass
|
||||
|
||||
|
||||
class Todo(CoreTodoTool, BaseAcpTool[AcpTodoState]):
|
||||
state: AcpTodoState
|
||||
prompt_path = VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "todo.md"
|
||||
|
||||
@classmethod
|
||||
def _get_tool_state_class(cls) -> type[AcpTodoState]:
|
||||
return AcpTodoState
|
||||
|
||||
@classmethod
|
||||
def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
|
||||
result = cast(TodoResult, event.result)
|
||||
todos = [todo for todo in result.todos if todo.status != TodoStatus.CANCELLED]
|
||||
matched_status: dict[TodoStatus, PlanEntryStatus] = {
|
||||
TodoStatus.PENDING: "pending",
|
||||
TodoStatus.IN_PROGRESS: "in_progress",
|
||||
TodoStatus.COMPLETED: "completed",
|
||||
}
|
||||
matched_priority: dict[TodoPriority, PlanEntryPriority] = {
|
||||
TodoPriority.LOW: "low",
|
||||
TodoPriority.MEDIUM: "medium",
|
||||
TodoPriority.HIGH: "high",
|
||||
}
|
||||
|
||||
update = AgentPlanUpdate(
|
||||
sessionUpdate="plan",
|
||||
entries=[
|
||||
PlanEntry(
|
||||
content=todo.content,
|
||||
status=matched_status[todo.status],
|
||||
priority=matched_priority[todo.priority],
|
||||
)
|
||||
for todo in todos
|
||||
],
|
||||
)
|
||||
return update
|
||||
98
vibe/acp/tools/builtins/write_file.py
Normal file
98
vibe/acp/tools/builtins/write_file.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from acp import WriteTextFileRequest
|
||||
from acp.helpers import SessionUpdate
|
||||
from acp.schema import (
|
||||
FileEditToolCallContent,
|
||||
ToolCallLocation,
|
||||
ToolCallProgress,
|
||||
ToolCallStart,
|
||||
)
|
||||
|
||||
from vibe import VIBE_ROOT
|
||||
from vibe.acp.tools.base import AcpToolState, BaseAcpTool
|
||||
from vibe.core.tools.base import ToolError
|
||||
from vibe.core.tools.builtins.write_file import (
|
||||
WriteFile as CoreWriteFileTool,
|
||||
WriteFileArgs,
|
||||
WriteFileResult,
|
||||
WriteFileState,
|
||||
)
|
||||
from vibe.core.types import ToolCallEvent, ToolResultEvent
|
||||
|
||||
|
||||
class AcpWriteFileState(WriteFileState, AcpToolState):
|
||||
pass
|
||||
|
||||
|
||||
class WriteFile(CoreWriteFileTool, BaseAcpTool[AcpWriteFileState]):
|
||||
state: AcpWriteFileState
|
||||
prompt_path = (
|
||||
VIBE_ROOT / "core" / "tools" / "builtins" / "prompts" / "write_file.md"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_tool_state_class(cls) -> type[AcpWriteFileState]:
|
||||
return AcpWriteFileState
|
||||
|
||||
async def _write_file(self, args: WriteFileArgs, file_path: Path) -> None:
|
||||
connection, session_id, _ = self._load_state()
|
||||
|
||||
write_request = WriteTextFileRequest(
|
||||
sessionId=session_id, path=str(file_path), content=args.content
|
||||
)
|
||||
|
||||
await self._send_in_progress_session_update()
|
||||
|
||||
try:
|
||||
await connection.writeTextFile(write_request)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Error writing {file_path}: {e}") from e
|
||||
|
||||
@classmethod
|
||||
def tool_call_session_update(cls, event: ToolCallEvent) -> SessionUpdate | None:
|
||||
args = event.args
|
||||
if not isinstance(args, WriteFileArgs):
|
||||
return None
|
||||
|
||||
return ToolCallStart(
|
||||
sessionUpdate="tool_call",
|
||||
title=cls.get_call_display(event).summary,
|
||||
toolCallId=event.tool_call_id,
|
||||
kind="edit",
|
||||
content=[
|
||||
FileEditToolCallContent(
|
||||
type="diff", path=args.path, oldText=None, newText=args.content
|
||||
)
|
||||
],
|
||||
locations=[ToolCallLocation(path=args.path)],
|
||||
rawInput=args.model_dump_json(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tool_result_session_update(cls, event: ToolResultEvent) -> SessionUpdate | None:
|
||||
if event.error:
|
||||
return ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=event.tool_call_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
result = event.result
|
||||
if not isinstance(result, WriteFileResult):
|
||||
return None
|
||||
|
||||
return ToolCallProgress(
|
||||
sessionUpdate="tool_call_update",
|
||||
toolCallId=event.tool_call_id,
|
||||
status="completed",
|
||||
content=[
|
||||
FileEditToolCallContent(
|
||||
type="diff", path=result.path, oldText=None, newText=result.content
|
||||
)
|
||||
],
|
||||
locations=[ToolCallLocation(path=result.path)],
|
||||
rawOutput=result.model_dump_json(),
|
||||
)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user