mirror of
https://github.com/RightNow-AI/openfang.git
synced 2026-04-25 17:25:11 +02:00
Rebased on latest main (f1ca527) after codebase changes. This is a
fresh submission after PR #22 was closed as stale.
## Why This Feature
Enables enterprise GCP deployments using existing service accounts
instead of requiring separate Gemini API keys. Many organizations
already have GCP infrastructure and prefer OAuth-based auth.
## What's New
- VertexAIDriver with full streaming support
- OAuth 2.0 token caching (50 min TTL) with auto-refresh via gcloud
- Auto-detection of project_id from service account JSON
- Security: tokens stored with Zeroizing<String>
- Provider aliases: vertex-ai, vertex, google-vertex
- Compatible with new ContentBlock::provider_metadata field
## Testing
- 6 unit tests passing
- Clippy clean (no warnings)
- End-to-end tested with real GCP service account + gemini-2.0-flash
- Both streaming and non-streaming paths verified
## Usage
export GOOGLE_APPLICATION_CREDENTIALS=/path/to/sa.json
# Set provider=vertex-ai, model=gemini-2.0-flash in config.toml
898 lines
32 KiB
Rust
//! LLM driver implementations.
|
|
//!
|
|
//! Contains drivers for Anthropic Claude, Google Gemini, OpenAI-compatible APIs, and more.
|
|
//! Supports: Anthropic, Gemini, OpenAI, Groq, OpenRouter, DeepSeek, Together,
|
|
//! Mistral, Fireworks, Ollama, vLLM, Chutes.ai, and any OpenAI-compatible endpoint.
|
|
|
|
pub mod anthropic;
|
|
pub mod claude_code;
|
|
pub mod copilot;
|
|
pub mod fallback;
|
|
pub mod gemini;
|
|
pub mod openai;
|
|
pub mod qwen_code;
|
|
pub mod vertex;
|
|
|
|
use crate::llm_driver::{DriverConfig, LlmDriver, LlmError};
|
|
use openfang_types::model_catalog::{
|
|
AI21_BASE_URL, ANTHROPIC_BASE_URL, AZURE_OPENAI_BASE_URL, CEREBRAS_BASE_URL, CHUTES_BASE_URL,
|
|
COHERE_BASE_URL, DEEPSEEK_BASE_URL, FIREWORKS_BASE_URL, GEMINI_BASE_URL, GROQ_BASE_URL,
|
|
HUGGINGFACE_BASE_URL, KIMI_CODING_BASE_URL, LEMONADE_BASE_URL, LMSTUDIO_BASE_URL,
|
|
MINIMAX_BASE_URL, MISTRAL_BASE_URL, MOONSHOT_BASE_URL, NVIDIA_NIM_BASE_URL, OLLAMA_BASE_URL,
|
|
OPENAI_BASE_URL, OPENROUTER_BASE_URL, PERPLEXITY_BASE_URL, QIANFAN_BASE_URL, QWEN_BASE_URL,
|
|
REPLICATE_BASE_URL, SAMBANOVA_BASE_URL, TOGETHER_BASE_URL, VENICE_BASE_URL, VLLM_BASE_URL,
|
|
VOLCENGINE_BASE_URL, VOLCENGINE_CODING_BASE_URL, XAI_BASE_URL, ZAI_BASE_URL,
|
|
ZAI_CODING_BASE_URL, ZHIPU_BASE_URL, ZHIPU_CODING_BASE_URL,
|
|
};
|
|
use std::sync::Arc;
|
|
|
|
/// Provider metadata: base URL and env var name for the API key.
struct ProviderDefaults {
    /// Default API endpoint; may be empty ("") for providers where the user
    /// must supply their own URL (e.g. the `claude-code` entry below).
    base_url: &'static str,
    /// Environment variable consulted for the API key when the config does
    /// not provide one explicitly.
    api_key_env: &'static str,
    /// If true, the API key is required (error if missing).
    key_required: bool,
}
|
|
|
|
/// Get defaults for known providers.
|
|
fn provider_defaults(provider: &str) -> Option<ProviderDefaults> {
|
|
match provider {
|
|
"groq" => Some(ProviderDefaults {
|
|
base_url: GROQ_BASE_URL,
|
|
api_key_env: "GROQ_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"openrouter" => Some(ProviderDefaults {
|
|
base_url: OPENROUTER_BASE_URL,
|
|
api_key_env: "OPENROUTER_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"deepseek" => Some(ProviderDefaults {
|
|
base_url: DEEPSEEK_BASE_URL,
|
|
api_key_env: "DEEPSEEK_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"together" => Some(ProviderDefaults {
|
|
base_url: TOGETHER_BASE_URL,
|
|
api_key_env: "TOGETHER_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"mistral" => Some(ProviderDefaults {
|
|
base_url: MISTRAL_BASE_URL,
|
|
api_key_env: "MISTRAL_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"fireworks" => Some(ProviderDefaults {
|
|
base_url: FIREWORKS_BASE_URL,
|
|
api_key_env: "FIREWORKS_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"openai" => Some(ProviderDefaults {
|
|
base_url: OPENAI_BASE_URL,
|
|
api_key_env: "OPENAI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"gemini" | "google" => Some(ProviderDefaults {
|
|
base_url: GEMINI_BASE_URL,
|
|
api_key_env: "GEMINI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"ollama" => Some(ProviderDefaults {
|
|
base_url: OLLAMA_BASE_URL,
|
|
api_key_env: "OLLAMA_API_KEY",
|
|
key_required: false,
|
|
}),
|
|
"vllm" => Some(ProviderDefaults {
|
|
base_url: VLLM_BASE_URL,
|
|
api_key_env: "VLLM_API_KEY",
|
|
key_required: false,
|
|
}),
|
|
"lmstudio" => Some(ProviderDefaults {
|
|
base_url: LMSTUDIO_BASE_URL,
|
|
api_key_env: "LMSTUDIO_API_KEY",
|
|
key_required: false,
|
|
}),
|
|
"lemonade" => Some(ProviderDefaults {
|
|
base_url: LEMONADE_BASE_URL,
|
|
api_key_env: "LEMONADE_API_KEY",
|
|
key_required: false,
|
|
}),
|
|
"perplexity" => Some(ProviderDefaults {
|
|
base_url: PERPLEXITY_BASE_URL,
|
|
api_key_env: "PERPLEXITY_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"cohere" => Some(ProviderDefaults {
|
|
base_url: COHERE_BASE_URL,
|
|
api_key_env: "COHERE_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"ai21" => Some(ProviderDefaults {
|
|
base_url: AI21_BASE_URL,
|
|
api_key_env: "AI21_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"cerebras" => Some(ProviderDefaults {
|
|
base_url: CEREBRAS_BASE_URL,
|
|
api_key_env: "CEREBRAS_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"sambanova" => Some(ProviderDefaults {
|
|
base_url: SAMBANOVA_BASE_URL,
|
|
api_key_env: "SAMBANOVA_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"huggingface" => Some(ProviderDefaults {
|
|
base_url: HUGGINGFACE_BASE_URL,
|
|
api_key_env: "HF_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"xai" => Some(ProviderDefaults {
|
|
base_url: XAI_BASE_URL,
|
|
api_key_env: "XAI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"replicate" => Some(ProviderDefaults {
|
|
base_url: REPLICATE_BASE_URL,
|
|
api_key_env: "REPLICATE_API_TOKEN",
|
|
key_required: true,
|
|
}),
|
|
"github-copilot" | "copilot" => Some(ProviderDefaults {
|
|
base_url: copilot::GITHUB_COPILOT_BASE_URL,
|
|
api_key_env: "GITHUB_TOKEN",
|
|
key_required: true,
|
|
}),
|
|
"codex" | "openai-codex" => Some(ProviderDefaults {
|
|
base_url: OPENAI_BASE_URL,
|
|
api_key_env: "OPENAI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"claude-code" => Some(ProviderDefaults {
|
|
base_url: "",
|
|
api_key_env: "",
|
|
key_required: false,
|
|
}),
|
|
"moonshot" | "kimi" | "kimi2" => Some(ProviderDefaults {
|
|
base_url: MOONSHOT_BASE_URL,
|
|
api_key_env: "MOONSHOT_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"kimi_coding" => Some(ProviderDefaults {
|
|
base_url: KIMI_CODING_BASE_URL,
|
|
api_key_env: "KIMI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"qwen" | "dashscope" | "model_studio" => Some(ProviderDefaults {
|
|
base_url: QWEN_BASE_URL,
|
|
api_key_env: "DASHSCOPE_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"minimax" => Some(ProviderDefaults {
|
|
base_url: MINIMAX_BASE_URL,
|
|
api_key_env: "MINIMAX_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"zhipu" | "glm" => Some(ProviderDefaults {
|
|
base_url: ZHIPU_BASE_URL,
|
|
api_key_env: "ZHIPU_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"zhipu_coding" | "codegeex" => Some(ProviderDefaults {
|
|
base_url: ZHIPU_CODING_BASE_URL,
|
|
api_key_env: "ZHIPU_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"zai" | "z.ai" => Some(ProviderDefaults {
|
|
base_url: ZAI_BASE_URL,
|
|
api_key_env: "ZHIPU_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"zai_coding" => Some(ProviderDefaults {
|
|
base_url: ZAI_CODING_BASE_URL,
|
|
api_key_env: "ZHIPU_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"qianfan" | "baidu" => Some(ProviderDefaults {
|
|
base_url: QIANFAN_BASE_URL,
|
|
api_key_env: "QIANFAN_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"volcengine" | "doubao" => Some(ProviderDefaults {
|
|
base_url: VOLCENGINE_BASE_URL,
|
|
api_key_env: "VOLCENGINE_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"volcengine_coding" => Some(ProviderDefaults {
|
|
base_url: VOLCENGINE_CODING_BASE_URL,
|
|
api_key_env: "VOLCENGINE_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"chutes" => Some(ProviderDefaults {
|
|
base_url: CHUTES_BASE_URL,
|
|
api_key_env: "CHUTES_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"venice" => Some(ProviderDefaults {
|
|
base_url: VENICE_BASE_URL,
|
|
api_key_env: "VENICE_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"nvidia" | "nvidia-nim" => Some(ProviderDefaults {
|
|
base_url: NVIDIA_NIM_BASE_URL,
|
|
api_key_env: "NVIDIA_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"azure" | "azure-openai" => Some(ProviderDefaults {
|
|
base_url: AZURE_OPENAI_BASE_URL,
|
|
api_key_env: "AZURE_OPENAI_API_KEY",
|
|
key_required: true,
|
|
}),
|
|
"vertex-ai" | "vertex" | "google-vertex" => Some(ProviderDefaults {
|
|
// Vertex AI uses OAuth, not API keys - base_url is per-project
|
|
base_url: "https://us-central1-aiplatform.googleapis.com",
|
|
api_key_env: "GOOGLE_APPLICATION_CREDENTIALS",
|
|
key_required: false, // Uses OAuth service account, not API key
|
|
}),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
/// Create an LLM driver based on provider name and configuration.
|
|
///
|
|
/// Supported providers:
|
|
/// - `anthropic` — Anthropic Claude (Messages API)
|
|
/// - `openai` — OpenAI GPT models
|
|
/// - `groq` — Groq (ultra-fast inference)
|
|
/// - `openrouter` — OpenRouter (multi-model gateway)
|
|
/// - `deepseek` — DeepSeek
|
|
/// - `together` — Together AI
|
|
/// - `mistral` — Mistral AI
|
|
/// - `fireworks` — Fireworks AI
|
|
/// - `ollama` — Ollama (local)
|
|
/// - `vllm` — vLLM (local)
|
|
/// - `lmstudio` — LM Studio (local)
|
|
/// - `perplexity` — Perplexity AI (search-augmented)
|
|
/// - `cohere` — Cohere (Command R)
|
|
/// - `ai21` — AI21 Labs (Jamba)
|
|
/// - `cerebras` — Cerebras (ultra-fast inference)
|
|
/// - `sambanova` — SambaNova
|
|
/// - `huggingface` — Hugging Face Inference API
|
|
/// - `xai` — xAI (Grok)
|
|
/// - `replicate` — Replicate
|
|
/// - `chutes` — Chutes.ai (serverless open-source model inference)
|
|
/// - Any custom provider with `base_url` set uses OpenAI-compatible format
|
|
pub fn create_driver(config: &DriverConfig) -> Result<Arc<dyn LlmDriver>, LlmError> {
|
|
let provider = config.provider.as_str();
|
|
|
|
// Anthropic uses a different API format — special case
|
|
if provider == "anthropic" {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey("Set ANTHROPIC_API_KEY environment variable".to_string())
|
|
})?;
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| ANTHROPIC_BASE_URL.to_string());
|
|
return Ok(Arc::new(anthropic::AnthropicDriver::new(api_key, base_url)));
|
|
}
|
|
|
|
// Gemini uses a different API format — special case
|
|
if provider == "gemini" || provider == "google" {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("GEMINI_API_KEY").ok())
|
|
.or_else(|| std::env::var("GOOGLE_API_KEY").ok())
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey(
|
|
"Set GEMINI_API_KEY or GOOGLE_API_KEY environment variable".to_string(),
|
|
)
|
|
})?;
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| GEMINI_BASE_URL.to_string());
|
|
return Ok(Arc::new(gemini::GeminiDriver::new(api_key, base_url)));
|
|
}
|
|
|
|
// Codex — reuses OpenAI driver with credential sync from Codex CLI
|
|
if provider == "codex" || provider == "openai-codex" {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("OPENAI_API_KEY").ok())
|
|
.or_else(crate::model_catalog::read_codex_credential)
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey("Set OPENAI_API_KEY or install Codex CLI".to_string())
|
|
})?;
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| OPENAI_BASE_URL.to_string());
|
|
return Ok(Arc::new(openai::OpenAIDriver::new(api_key, base_url)));
|
|
}
|
|
|
|
// Claude Code CLI — subprocess-based, no API key needed
|
|
if provider == "claude-code" {
|
|
let cli_path = config.base_url.clone();
|
|
return Ok(Arc::new(claude_code::ClaudeCodeDriver::new(
|
|
cli_path,
|
|
config.skip_permissions,
|
|
)));
|
|
}
|
|
|
|
// Qwen Code CLI — subprocess-based, uses Qwen OAuth (free tier)
|
|
if provider == "qwen-code" {
|
|
let cli_path = config.base_url.clone();
|
|
return Ok(Arc::new(qwen_code::QwenCodeDriver::new(
|
|
cli_path,
|
|
config.skip_permissions,
|
|
)));
|
|
}
|
|
|
|
// GitHub Copilot — wraps OpenAI-compatible driver with automatic token exchange.
|
|
// The CopilotDriver exchanges the GitHub PAT for a Copilot API token on demand,
|
|
// caches it, and refreshes when expired.
|
|
if provider == "github-copilot" || provider == "copilot" {
|
|
let github_token = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("GITHUB_TOKEN").ok())
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey(
|
|
"Set GITHUB_TOKEN environment variable for GitHub Copilot".to_string(),
|
|
)
|
|
})?;
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| copilot::GITHUB_COPILOT_BASE_URL.to_string());
|
|
return Ok(Arc::new(copilot::CopilotDriver::new(
|
|
github_token,
|
|
base_url,
|
|
)));
|
|
}
|
|
|
|
// Azure OpenAI — deployment-based URL with `api-key` header
|
|
if provider == "azure" || provider == "azure-openai" {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("AZURE_OPENAI_API_KEY").ok())
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey(
|
|
"Set AZURE_OPENAI_API_KEY environment variable for Azure OpenAI".to_string(),
|
|
)
|
|
})?;
|
|
let base_url = config.base_url.clone().ok_or_else(|| LlmError::Api {
|
|
status: 0,
|
|
message: "Azure OpenAI requires base_url — set it to \
|
|
https://{resource}.openai.azure.com/openai/deployments"
|
|
.to_string(),
|
|
})?;
|
|
return Ok(Arc::new(openai::OpenAIDriver::new_azure(api_key, base_url)));
|
|
}
|
|
|
|
// Vertex AI — uses Google Cloud OAuth with service account credentials.
|
|
// Requires GOOGLE_APPLICATION_CREDENTIALS env var pointing to service account JSON,
|
|
// and the service account must be activated via gcloud CLI.
|
|
if provider == "vertex-ai" || provider == "vertex" || provider == "google-vertex" {
|
|
// Get project_id from environment or service account JSON
|
|
let project_id = std::env::var("GOOGLE_CLOUD_PROJECT")
|
|
.or_else(|_| std::env::var("GCLOUD_PROJECT"))
|
|
.or_else(|_| std::env::var("GCP_PROJECT"))
|
|
.or_else(|_| {
|
|
// Try to read from service account JSON
|
|
if let Ok(creds_path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") {
|
|
if let Ok(contents) = std::fs::read_to_string(&creds_path) {
|
|
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&contents) {
|
|
if let Some(proj) = json.get("project_id").and_then(|v| v.as_str()) {
|
|
return Ok(proj.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
Err(std::env::VarError::NotPresent)
|
|
})
|
|
.map_err(|_| {
|
|
LlmError::MissingApiKey(
|
|
"Set GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_CLOUD_PROJECT for Vertex AI"
|
|
.to_string(),
|
|
)
|
|
})?;
|
|
let region = std::env::var("GOOGLE_CLOUD_REGION")
|
|
.or_else(|_| std::env::var("VERTEX_AI_REGION"))
|
|
.unwrap_or_else(|_| "us-central1".to_string());
|
|
return Ok(Arc::new(vertex::VertexAIDriver::new(project_id, region)));
|
|
}
|
|
|
|
// Kimi for Code — Anthropic-compatible endpoint
|
|
if provider == "kimi_coding" {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var("KIMI_API_KEY").ok())
|
|
.ok_or_else(|| {
|
|
LlmError::MissingApiKey("Set KIMI_API_KEY environment variable".to_string())
|
|
})?;
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| KIMI_CODING_BASE_URL.to_string());
|
|
return Ok(Arc::new(anthropic::AnthropicDriver::new(api_key, base_url)));
|
|
}
|
|
|
|
// All other providers use OpenAI-compatible format
|
|
if let Some(defaults) = provider_defaults(provider) {
|
|
let api_key = config
|
|
.api_key
|
|
.clone()
|
|
.or_else(|| std::env::var(defaults.api_key_env).ok())
|
|
.unwrap_or_default();
|
|
|
|
if defaults.key_required && api_key.is_empty() {
|
|
return Err(LlmError::MissingApiKey(format!(
|
|
"Set {} environment variable for provider '{}'",
|
|
defaults.api_key_env, provider
|
|
)));
|
|
}
|
|
|
|
let base_url = config
|
|
.base_url
|
|
.clone()
|
|
.unwrap_or_else(|| defaults.base_url.to_string());
|
|
|
|
return Ok(Arc::new(openai::OpenAIDriver::new(api_key, base_url)));
|
|
}
|
|
|
|
// Unknown provider — if base_url is set, treat as custom OpenAI-compatible.
|
|
// For custom providers, try the convention {PROVIDER_UPPER}_API_KEY as env var
|
|
// when no explicit api_key was passed. This lets users just set e.g. NVIDIA_API_KEY
|
|
// in their environment and use provider = "nvidia" without extra config.
|
|
if let Some(ref base_url) = config.base_url {
|
|
let api_key = config.api_key.clone().unwrap_or_else(|| {
|
|
let env_var = format!("{}_API_KEY", provider.to_uppercase().replace('-', "_"));
|
|
std::env::var(&env_var).unwrap_or_default()
|
|
});
|
|
return Ok(Arc::new(openai::OpenAIDriver::new(
|
|
api_key,
|
|
base_url.clone(),
|
|
)));
|
|
}
|
|
|
|
// No base_url either — last resort: check if the user set an API key env var
|
|
// using the convention {PROVIDER_UPPER}_API_KEY. If found, use OpenAI-compatible
|
|
// driver with a default base URL derived from common patterns.
|
|
{
|
|
let env_var = format!("{}_API_KEY", provider.to_uppercase().replace('-', "_"));
|
|
if let Ok(api_key) = std::env::var(&env_var) {
|
|
if !api_key.is_empty() {
|
|
return Err(LlmError::Api {
|
|
status: 0,
|
|
message: format!(
|
|
"Provider '{}' has API key ({} is set) but no base_url configured. \
|
|
Add base_url to your [default_model] config or set it in [provider_urls].",
|
|
provider, env_var
|
|
),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
Err(LlmError::Api {
|
|
status: 0,
|
|
message: format!(
|
|
"Unknown provider '{}'. Supported: anthropic, gemini, openai, azure, groq, openrouter, \
|
|
deepseek, together, mistral, fireworks, ollama, vllm, lmstudio, perplexity, \
|
|
cohere, ai21, cerebras, sambanova, huggingface, xai, replicate, github-copilot, \
|
|
chutes, venice, nvidia, codex, claude-code. Or set base_url for a custom OpenAI-compatible endpoint.",
|
|
provider
|
|
),
|
|
})
|
|
}
|
|
|
|
/// Detect the first available provider by scanning environment variables.
///
/// Returns `(provider, model, api_key_env)` for the first provider that has a
/// configured (non-empty) API key, checked in a user-friendly priority order.
pub fn detect_available_provider() -> Option<(&'static str, &'static str, &'static str)> {
    // Priority: popular cloud providers first, then niche, then local
    const PROBE_ORDER: &[(&str, &str, &str)] = &[
        ("openai", "gpt-4o", "OPENAI_API_KEY"),
        ("anthropic", "claude-sonnet-4-20250514", "ANTHROPIC_API_KEY"),
        ("gemini", "gemini-2.5-flash", "GEMINI_API_KEY"),
        ("groq", "llama-3.3-70b-versatile", "GROQ_API_KEY"),
        ("deepseek", "deepseek-chat", "DEEPSEEK_API_KEY"),
        (
            "openrouter",
            "openrouter/google/gemini-2.5-flash",
            "OPENROUTER_API_KEY",
        ),
        ("mistral", "mistral-large-latest", "MISTRAL_API_KEY"),
        (
            "together",
            "meta-llama/Llama-3-70b-chat-hf",
            "TOGETHER_API_KEY",
        ),
        (
            "fireworks",
            "accounts/fireworks/models/llama-v3p1-70b-instruct",
            "FIREWORKS_API_KEY",
        ),
        ("xai", "grok-2", "XAI_API_KEY"),
        (
            "perplexity",
            "llama-3.1-sonar-large-128k-online",
            "PERPLEXITY_API_KEY",
        ),
        ("cohere", "command-r-plus", "COHERE_API_KEY"),
    ];

    // A key counts as "configured" only when the env var is set AND non-empty.
    fn key_configured(env_var: &str) -> bool {
        std::env::var(env_var).map(|v| !v.is_empty()).unwrap_or(false)
    }

    PROBE_ORDER
        .iter()
        .copied()
        .find(|&(_, _, env_var)| key_configured(env_var))
        .or_else(|| {
            // GOOGLE_API_KEY is accepted as an alias for the Gemini key.
            key_configured("GOOGLE_API_KEY")
                .then_some(("gemini", "gemini-2.5-flash", "GOOGLE_API_KEY"))
        })
}
|
|
|
|
/// List all known provider names.
pub fn known_providers() -> &'static [&'static str] {
    // Kept in registration order; `known_providers().len()` is asserted in tests.
    static PROVIDERS: [&str; 37] = [
        "anthropic",
        "gemini",
        "openai",
        "groq",
        "openrouter",
        "deepseek",
        "together",
        "mistral",
        "fireworks",
        "ollama",
        "vllm",
        "lmstudio",
        "perplexity",
        "cohere",
        "ai21",
        "cerebras",
        "sambanova",
        "huggingface",
        "xai",
        "replicate",
        "github-copilot",
        "moonshot",
        "qwen",
        "minimax",
        "zhipu",
        "zhipu_coding",
        "zai",
        "kimi_coding",
        "qianfan",
        "volcengine",
        "chutes",
        "venice",
        "nvidia",
        "codex",
        "claude-code",
        "qwen-code",
        "azure",
    ];
    &PROVIDERS
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes tests that set/remove process environment variables.
    ///
    /// `cargo test` runs tests on multiple threads by default, so e.g.
    /// `test_nvidia_provider_with_env_key` setting `NVIDIA_API_KEY` could race
    /// with `test_nvidia_provider_no_key_errors`, which relies on that key
    /// being absent. Every env-mutating (or env-absence-dependent) test takes
    /// this lock first.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire the env lock, recovering from poisoning so one panicking test
    /// doesn't cascade into spurious failures in the others.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_provider_defaults_groq() {
        let d = provider_defaults("groq").unwrap();
        assert_eq!(d.base_url, "https://api.groq.com/openai/v1");
        assert_eq!(d.api_key_env, "GROQ_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_openrouter() {
        let d = provider_defaults("openrouter").unwrap();
        assert_eq!(d.base_url, "https://openrouter.ai/api/v1");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_ollama() {
        let d = provider_defaults("ollama").unwrap();
        assert!(!d.key_required);
    }

    #[test]
    fn test_unknown_provider_returns_none() {
        assert!(provider_defaults("nonexistent").is_none());
    }

    #[test]
    fn test_custom_provider_with_base_url() {
        let config = DriverConfig {
            provider: "my-custom-llm".to_string(),
            api_key: Some("test".to_string()),
            base_url: Some("http://localhost:9999/v1".to_string()),
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
    }

    #[test]
    fn test_unknown_provider_no_url_errors() {
        let config = DriverConfig {
            provider: "nonexistent".to_string(),
            api_key: None,
            base_url: None,
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(driver.is_err());
    }

    #[test]
    fn test_provider_defaults_gemini() {
        let d = provider_defaults("gemini").unwrap();
        assert_eq!(d.base_url, "https://generativelanguage.googleapis.com");
        assert_eq!(d.api_key_env, "GEMINI_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_google_alias() {
        let d = provider_defaults("google").unwrap();
        assert_eq!(d.base_url, "https://generativelanguage.googleapis.com");
        assert!(d.key_required);
    }

    #[test]
    fn test_known_providers_list() {
        let providers = known_providers();
        assert!(providers.contains(&"groq"));
        assert!(providers.contains(&"openrouter"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"gemini"));
        // New providers
        assert!(providers.contains(&"perplexity"));
        assert!(providers.contains(&"cohere"));
        assert!(providers.contains(&"ai21"));
        assert!(providers.contains(&"cerebras"));
        assert!(providers.contains(&"sambanova"));
        assert!(providers.contains(&"huggingface"));
        assert!(providers.contains(&"xai"));
        assert!(providers.contains(&"replicate"));
        assert!(providers.contains(&"github-copilot"));
        assert!(providers.contains(&"moonshot"));
        assert!(providers.contains(&"qwen"));
        assert!(providers.contains(&"minimax"));
        assert!(providers.contains(&"zhipu"));
        assert!(providers.contains(&"zhipu_coding"));
        assert!(providers.contains(&"zai"));
        assert!(providers.contains(&"kimi_coding"));
        assert!(providers.contains(&"qianfan"));
        assert!(providers.contains(&"volcengine"));
        assert!(providers.contains(&"chutes"));
        assert!(providers.contains(&"nvidia"));
        assert!(providers.contains(&"codex"));
        assert!(providers.contains(&"claude-code"));
        assert!(providers.contains(&"qwen-code"));
        assert!(providers.contains(&"azure"));
        assert_eq!(providers.len(), 37);
    }

    #[test]
    fn test_provider_defaults_perplexity() {
        let d = provider_defaults("perplexity").unwrap();
        assert_eq!(d.base_url, "https://api.perplexity.ai");
        assert_eq!(d.api_key_env, "PERPLEXITY_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_xai() {
        let d = provider_defaults("xai").unwrap();
        assert_eq!(d.base_url, "https://api.x.ai/v1");
        assert_eq!(d.api_key_env, "XAI_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_cohere() {
        let d = provider_defaults("cohere").unwrap();
        assert_eq!(d.base_url, "https://api.cohere.com/v2");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_cerebras() {
        let d = provider_defaults("cerebras").unwrap();
        assert_eq!(d.base_url, "https://api.cerebras.ai/v1");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_huggingface() {
        let d = provider_defaults("huggingface").unwrap();
        assert_eq!(d.base_url, "https://api-inference.huggingface.co/v1");
        assert_eq!(d.api_key_env, "HF_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_nvidia_provider_with_env_key() {
        // NVIDIA NIM is a known provider — set API key and verify driver creation succeeds.
        let _guard = env_guard();
        let unique_key = "test-nvidia-key-12345";
        std::env::set_var("NVIDIA_API_KEY", unique_key);
        let config = DriverConfig {
            provider: "nvidia".to_string(),
            api_key: None, // picked up from env via provider_defaults
            base_url: None,
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        // Clean up before asserting so a failure doesn't leak the env var.
        std::env::remove_var("NVIDIA_API_KEY");
        assert!(
            driver.is_ok(),
            "NVIDIA provider with env var should succeed"
        );
    }

    #[test]
    fn test_nvidia_provider_no_key_errors() {
        // NVIDIA NIM provider with no API key should error.
        let _guard = env_guard();
        // Ensure an ambient key can't mask the expected failure.
        std::env::remove_var("NVIDIA_API_KEY");
        let config = DriverConfig {
            provider: "nvidia".to_string(),
            api_key: None,
            base_url: None,
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(driver.is_err());
    }

    #[test]
    fn test_custom_provider_key_no_url_helpful_error() {
        // Custom provider with key set (via env) but no base_url should give helpful error.
        let _guard = env_guard();
        let unique_key = "test-custom-key-67890";
        std::env::set_var("MYCUSTOM_API_KEY", unique_key);
        let config = DriverConfig {
            provider: "mycustom".to_string(),
            api_key: None,
            base_url: None,
            skip_permissions: true,
        };
        let result = create_driver(&config);
        // Clean up before asserting so a failure doesn't leak the env var.
        std::env::remove_var("MYCUSTOM_API_KEY");
        assert!(result.is_err());
        let err = result.err().unwrap().to_string();
        assert!(
            err.contains("base_url"),
            "Error should mention base_url: {}",
            err
        );
    }

    #[test]
    fn test_provider_defaults_kimi_coding() {
        let d = provider_defaults("kimi_coding").unwrap();
        assert_eq!(d.base_url, "https://api.kimi.com/coding");
        assert_eq!(d.api_key_env, "KIMI_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_custom_provider_explicit_key_with_url() {
        // When api_key is explicitly passed, it should be used regardless of env var.
        let config = DriverConfig {
            provider: "my-custom-provider".to_string(),
            api_key: Some("explicit-key".to_string()),
            base_url: Some("https://api.example.com/v1".to_string()),
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok());
    }

    #[test]
    fn test_provider_defaults_azure() {
        let d = provider_defaults("azure").unwrap();
        assert_eq!(d.base_url, ""); // Azure requires user-supplied URL
        assert_eq!(d.api_key_env, "AZURE_OPENAI_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_provider_defaults_azure_openai_alias() {
        let d = provider_defaults("azure-openai").unwrap();
        assert_eq!(d.api_key_env, "AZURE_OPENAI_API_KEY");
        assert!(d.key_required);
    }

    #[test]
    fn test_azure_driver_creation_with_key_and_url() {
        let config = DriverConfig {
            provider: "azure".to_string(),
            api_key: Some("test-azure-key".to_string()),
            base_url: Some(
                "https://myresource.openai.azure.com/openai/deployments".to_string(),
            ),
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(driver.is_ok(), "Azure driver with key + URL should succeed");
    }

    #[test]
    fn test_azure_driver_no_key_errors() {
        let config = DriverConfig {
            provider: "azure".to_string(),
            api_key: None,
            base_url: Some(
                "https://myresource.openai.azure.com/openai/deployments".to_string(),
            ),
            skip_permissions: true,
        };
        let result = create_driver(&config);
        assert!(result.is_err(), "Azure driver without key should error");
        let err = result.err().unwrap().to_string();
        assert!(
            err.contains("AZURE_OPENAI_API_KEY"),
            "Error should mention AZURE_OPENAI_API_KEY: {}",
            err
        );
    }

    #[test]
    fn test_azure_driver_no_url_errors() {
        let config = DriverConfig {
            provider: "azure".to_string(),
            api_key: Some("test-azure-key".to_string()),
            base_url: None,
            skip_permissions: true,
        };
        let result = create_driver(&config);
        assert!(result.is_err(), "Azure driver without URL should error");
        let err = result.err().unwrap().to_string();
        assert!(
            err.contains("base_url"),
            "Error should mention base_url: {}",
            err
        );
    }

    #[test]
    fn test_azure_openai_alias_driver_creation() {
        let config = DriverConfig {
            provider: "azure-openai".to_string(),
            api_key: Some("test-azure-key".to_string()),
            base_url: Some(
                "https://myresource.openai.azure.com/openai/deployments".to_string(),
            ),
            skip_permissions: true,
        };
        let driver = create_driver(&config);
        assert!(
            driver.is_ok(),
            "azure-openai alias should create driver successfully"
        );
    }
}
|