diff --git a/crates/openfang-runtime/src/drivers/mod.rs b/crates/openfang-runtime/src/drivers/mod.rs index b25c6aed..428fda3c 100644 --- a/crates/openfang-runtime/src/drivers/mod.rs +++ b/crates/openfang-runtime/src/drivers/mod.rs @@ -11,6 +11,7 @@ pub mod fallback; pub mod gemini; pub mod openai; pub mod qwen_code; +pub mod vertex; use crate::llm_driver::{DriverConfig, LlmDriver, LlmError}; use openfang_types::model_catalog::{ @@ -226,6 +227,12 @@ fn provider_defaults(provider: &str) -> Option { api_key_env: "AZURE_OPENAI_API_KEY", key_required: true, }), + "vertex-ai" | "vertex" | "google-vertex" => Some(ProviderDefaults { + // Vertex AI uses OAuth, not API keys - base_url is per-project + base_url: "https://us-central1-aiplatform.googleapis.com", + api_key_env: "GOOGLE_APPLICATION_CREDENTIALS", + key_required: false, // Uses OAuth service account, not API key + }), _ => None, } } @@ -370,6 +377,39 @@ pub fn create_driver(config: &DriverConfig) -> Result, LlmErr return Ok(Arc::new(openai::OpenAIDriver::new_azure(api_key, base_url))); } + // Vertex AI — uses Google Cloud OAuth with service account credentials. + // Requires GOOGLE_APPLICATION_CREDENTIALS env var pointing to service account JSON, + // and the service account must be activated via gcloud CLI. 
+ if provider == "vertex-ai" || provider == "vertex" || provider == "google-vertex" { + // Get project_id from environment or service account JSON + let project_id = std::env::var("GOOGLE_CLOUD_PROJECT") + .or_else(|_| std::env::var("GCLOUD_PROJECT")) + .or_else(|_| std::env::var("GCP_PROJECT")) + .or_else(|_| { + // Try to read from service account JSON + if let Ok(creds_path) = std::env::var("GOOGLE_APPLICATION_CREDENTIALS") { + if let Ok(contents) = std::fs::read_to_string(&creds_path) { + if let Ok(json) = serde_json::from_str::(&contents) { + if let Some(proj) = json.get("project_id").and_then(|v| v.as_str()) { + return Ok(proj.to_string()); + } + } + } + } + Err(std::env::VarError::NotPresent) + }) + .map_err(|_| { + LlmError::MissingApiKey( + "Set GOOGLE_APPLICATION_CREDENTIALS or GOOGLE_CLOUD_PROJECT for Vertex AI" + .to_string(), + ) + })?; + let region = std::env::var("GOOGLE_CLOUD_REGION") + .or_else(|_| std::env::var("VERTEX_AI_REGION")) + .unwrap_or_else(|_| "us-central1".to_string()); + return Ok(Arc::new(vertex::VertexAIDriver::new(project_id, region))); + } + // Kimi for Code — Anthropic-compatible endpoint if provider == "kimi_coding" { let api_key = config diff --git a/crates/openfang-runtime/src/drivers/vertex.rs b/crates/openfang-runtime/src/drivers/vertex.rs new file mode 100644 index 00000000..858e6316 --- /dev/null +++ b/crates/openfang-runtime/src/drivers/vertex.rs @@ -0,0 +1,790 @@ +//! Google Vertex AI driver with OAuth authentication. +//! +//! Uses service account credentials (`GOOGLE_APPLICATION_CREDENTIALS`) to +//! authenticate with Vertex AI's Gemini models via OAuth 2.0 bearer tokens. +//! This enables enterprise deployments without requiring consumer API keys. +//! +//! # Endpoint Format +//! +//! ```text +//! https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:generateContent +//! ``` +//! +//! # Authentication +//! +//! 
Uses OAuth 2.0 bearer tokens taken from `VERTEX_AI_ACCESS_TOKEN` or obtained via the `gcloud` CLI
+//! (`gcloud auth application-default print-access-token`, then `gcloud auth print-access-token`).
+//! Tokens are cached for 50 minutes and automatically refreshed.
+//!
+//! # Environment Variables
+//!
+//! - `GOOGLE_APPLICATION_CREDENTIALS` — Path to service account JSON
+//! - `GOOGLE_CLOUD_PROJECT` / `GCLOUD_PROJECT` / `GCP_PROJECT` — Project ID (optional if in credentials)
+//! - `GOOGLE_CLOUD_REGION` / `VERTEX_AI_REGION` — Region (default: `us-central1`)
+//! - `VERTEX_AI_ACCESS_TOKEN` — Pre-generated token (optional, for testing)
+
+use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent};
+use async_trait::async_trait;
+use futures::StreamExt;
+use openfang_types::message::{
+    ContentBlock, Message, MessageContent, Role, StopReason, TokenUsage,
+};
+use openfang_types::tool::ToolCall;
+use serde::{Deserialize, Serialize};
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tokio::sync::RwLock;
+use tracing::{debug, info, warn};
+use zeroize::Zeroizing;
+
+/// Vertex AI driver with OAuth authentication.
+///
+/// Authenticates using GCP service account credentials and OAuth 2.0 bearer tokens.
+/// Tokens are cached with automatic refresh before expiry.
+pub struct VertexAIDriver {
+    /// GCP project ID used in the endpoint path.
+    project_id: String,
+    /// GCP location used for both the host name and the endpoint path.
+    region: String,
+    /// Cached OAuth access token (zeroized on drop for security).
+    /// Shared behind an async `RwLock` so concurrent requests can read the
+    /// cached token while a refresh takes the write lock.
+    token_cache: Arc<RwLock<TokenCache>>,
+    /// HTTP client reused for every request made by this driver.
+    client: reqwest::Client,
+}
+
+/// Cached OAuth token with expiry tracking.
+///
+/// SECURITY: Token is wrapped in `Zeroizing` to clear memory on drop.
+struct TokenCache {
+    /// Bearer token; `None` until the first successful fetch.
+    token: Option<Zeroizing<String>>,
+    /// Instant after which `token` is no longer served from the cache.
+    expires_at: Option<Instant>,
+}
+
+impl TokenCache {
+    /// Create an empty cache (no token, no expiry).
+    fn new() -> Self {
+        Self {
+            token: None,
+            expires_at: None,
+        }
+    }
+
+    /// Returns `true` when a token is present and its expiry lies in the future.
+    fn is_valid(&self) -> bool {
+        matches!(
+            (&self.token, self.expires_at),
+            (Some(_), Some(expires)) if Instant::now() < expires
+        )
+    }
+
+    /// Returns a copy of the cached token, or `None` when it is missing or expired.
+    fn get(&self) -> Option<String> {
+        if !self.is_valid() {
+            return None;
+        }
+        self.token.as_ref().map(|t| t.as_str().to_string())
+    }
+}
+
+impl VertexAIDriver {
+    /// How long a fetched token is served from the cache before it is refreshed.
+    const TOKEN_TTL: Duration = Duration::from_secs(50 * 60);
+
+    /// Create a new Vertex AI driver.
+    ///
+    /// # Arguments
+    /// * `project_id` - GCP project ID
+    /// * `region` - GCP region (e.g., `us-central1`)
+    pub fn new(project_id: String, region: String) -> Self {
+        Self {
+            project_id,
+            region,
+            token_cache: Arc::new(RwLock::new(TokenCache::new())),
+            client: reqwest::Client::new(),
+        }
+    }
+
+    /// Get a valid OAuth access token, refreshing if needed.
+    ///
+    /// # Errors
+    /// Propagates the error from `fetch_access_token` when no cached token is
+    /// usable and a new one cannot be obtained.
+    async fn get_access_token(&self) -> Result<String, LlmError> {
+        // Fast path: serve from the cache under the read lock only.
+        if let Some(token) = self.token_cache.read().await.get() {
+            debug!("Using cached Vertex AI access token");
+            return Ok(token);
+        }
+
+        // Need to refresh token
+        info!("Refreshing Vertex AI OAuth access token");
+        let token = self.fetch_access_token().await?;
+
+        // Cache the token (expires in ~1 hour, we refresh at 50 min)
+        {
+            let mut cache = self.token_cache.write().await;
+            cache.token = Some(Zeroizing::new(token.clone()));
+            cache.expires_at = Some(Instant::now() + Self::TOKEN_TTL);
+        }
+
+        Ok(token)
+    }
+
+    /// Fetch a new access token using gcloud CLI.
+    ///
+    /// This uses the service account specified in GOOGLE_APPLICATION_CREDENTIALS
+    /// via the gcloud CLI. For production, this should use the google-auth library.
+    async fn fetch_access_token(&self) -> Result<String, LlmError> {
+        // A pre-generated token in the environment wins over any gcloud lookup.
+        if let Ok(token) = std::env::var("VERTEX_AI_ACCESS_TOKEN") {
+            // Trim so a stray newline from shell capture cannot end up in the header.
+            let token = token.trim();
+            if !token.is_empty() {
+                debug!("Using pre-set VERTEX_AI_ACCESS_TOKEN");
+                return Ok(token.to_string());
+            }
+        }
+
+        // Try application-default credentials first (uses GOOGLE_APPLICATION_CREDENTIALS).
+        // Any failure here is deliberately ignored: the plain `gcloud auth` path below
+        // is the one whose error is reported to the caller.
+        if let Ok(output) =
+            Self::run_gcloud(&["auth", "application-default", "print-access-token"]).await
+        {
+            if output.status.success() {
+                let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
+                if !token.is_empty() {
+                    debug!("Successfully obtained Vertex AI access token via application-default");
+                    return Ok(token);
+                }
+            }
+        }
+
+        // Fall back to regular gcloud auth (requires activated service account)
+        let output = Self::run_gcloud(&["auth", "print-access-token"])
+            .await
+            .map_err(|e| LlmError::MissingApiKey(format!("Failed to run gcloud: {}", e)))?;
+
+        if !output.status.success() {
+            let stderr = String::from_utf8_lossy(&output.stderr);
+            return Err(LlmError::MissingApiKey(format!(
+                "gcloud auth failed: {}. Ensure GOOGLE_APPLICATION_CREDENTIALS is set and \
+                 run: gcloud auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS",
+                stderr.trim()
+            )));
+        }
+
+        let token = String::from_utf8_lossy(&output.stdout).trim().to_string();
+        if token.is_empty() {
+            return Err(LlmError::MissingApiKey(
+                "Empty access token from gcloud".to_string(),
+            ));
+        }
+
+        debug!("Successfully obtained Vertex AI access token");
+        Ok(token)
+    }
+
+    /// Run the `gcloud` CLI with `args` and capture its output.
+    ///
+    /// On Windows the Cloud SDK ships `gcloud.cmd`; process spawning only
+    /// resolves an omitted `.exe` extension, so the `.cmd` name is spelled out
+    /// there instead of relying on a bare `gcloud`.
+    async fn run_gcloud(args: &[&str]) -> std::io::Result<std::process::Output> {
+        let program = if cfg!(windows) { "gcloud.cmd" } else { "gcloud" };
+        tokio::process::Command::new(program)
+            .args(args)
+            .output()
+            .await
+    }
+
+    /// Build the Vertex AI endpoint URL for a model.
+    fn build_endpoint(&self, model: &str, streaming: bool) -> String {
+        // Callers may pass Gemini-API style ids ("models/gemini-2.0-flash");
+        // the Vertex path takes the bare model name, so drop that prefix.
+        let model_name = model.strip_prefix("models/").unwrap_or(model);
+
+        let method = if streaming {
+            "streamGenerateContent"
+        } else {
+            "generateContent"
+        };
+
+        // The `global` location is served from the un-prefixed host; every other
+        // location uses `{region}-aiplatform.googleapis.com`.
+        let host = if self.region == "global" {
+            "aiplatform.googleapis.com".to_string()
+        } else {
+            format!("{}-aiplatform.googleapis.com", self.region)
+        };
+
+        format!(
+            "https://{host}/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:{method}",
+            host = host,
+            region = self.region,
+            project = self.project_id,
+            model = model_name,
+            method = method
+        )
+    }
+}
+
+// ── Request types (reusing Gemini format) ──────────────────────────────
+
+/// Top-level Gemini/Vertex API request body.
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexRequest {
+    contents: Vec<VertexContent>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    system_instruction: Option<VertexContent>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    tools: Vec<VertexToolConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    generation_config: Option<GenerationConfig>,
+}
+
+/// One conversation turn (or the system instruction when `role` is `None`).
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexContent {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    role: Option<String>,
+    // A candidate may carry a role but no parts; treat that as an empty turn
+    // instead of failing the whole response parse.
+    #[serde(default)]
+    parts: Vec<VertexPart>,
+}
+
+/// A single part of a content turn.
+///
+/// Untagged: serde tries the variants in declaration order and picks the first
+/// whose required field is present.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+#[serde(untagged)]
+enum VertexPart {
+    Text {
+        text: String,
+    },
+    InlineData {
+        #[serde(rename = "inlineData")]
+        inline_data: VertexInlineData,
+    },
+    FunctionCall {
+        #[serde(rename = "functionCall")]
+        function_call: VertexFunctionCallData,
+    },
+    FunctionResponse {
+        #[serde(rename = "functionResponse")]
+        function_response: VertexFunctionResponseData,
+    },
+    /// Any part kind this driver does not model. Kept last so it only matches
+    /// when nothing above does; callers ignore it rather than failing the parse.
+    Other(serde_json::Value),
+}
+
+/// Base64 payload with its MIME type.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexInlineData {
+    #[serde(rename = "mimeType")]
+    mime_type: String,
+    data: String,
+}
+
+/// A tool invocation requested by the model.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexFunctionCallData {
+    name: String,
+    // A call without arguments may omit `args`; default to `{}` so the call is
+    // still recognised as a function call.
+    #[serde(default = "empty_json_object")]
+    args: serde_json::Value,
+}
+
+/// Default for omitted function-call arguments: an empty JSON object.
+fn empty_json_object() -> serde_json::Value {
+    serde_json::Value::Object(serde_json::Map::new())
+}
+
+/// The result of a tool invocation sent back to the model.
+#[derive(Debug, Serialize, Deserialize, Clone)]
+struct VertexFunctionResponseData {
+    name: String,
+    response: serde_json::Value,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexToolConfig {
+    function_declarations: Vec<VertexFunctionDeclaration>,
+}
+
+#[derive(Debug, Serialize)]
+struct VertexFunctionDeclaration {
+    name: String,
+    description: String,
+    parameters: serde_json::Value,
+}
+
+// NOTE(review): the numeric types below are assumed to match
+// `CompletionRequest::temperature` / `CompletionRequest::max_tokens` — confirm.
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct GenerationConfig {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    max_output_tokens: Option<u32>,
+}
+
+// ── Response types ─────────────────────────────────────────────────────
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexResponse {
+    #[serde(default)]
+    candidates: Vec<VertexCandidate>,
+    #[serde(default)]
+    usage_metadata: Option<VertexUsageMetadata>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexCandidate {
+    content: Option<VertexContent>,
+    #[serde(default)]
+    finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct VertexUsageMetadata {
+    #[serde(default)]
+    prompt_token_count: u64,
+    #[serde(default)]
+    candidates_token_count: u64,
+}
+
+#[derive(Debug, Deserialize)]
+struct VertexErrorResponse {
+    error: VertexErrorDetail,
+}
+
+#[derive(Debug, Deserialize)]
+struct VertexErrorDetail {
+    message: String,
+}
+
+// ── Message conversion ─────────────────────────────────────────────────
+
+/// Convert driver-neutral messages into Vertex `contents` plus the optional
+/// system instruction.
+///
+/// System-role messages are left out of `contents`; turns that end up with no
+/// parts are dropped.
+fn convert_messages(
+    messages: &[Message],
+    system: &Option<String>,
+) -> (Vec<VertexContent>, Option<VertexContent>) {
+    let mut contents = Vec::new();
+
+    let system_instruction = extract_system(messages, system);
+
+    for msg in messages {
+        let role = match msg.role {
+            Role::User => "user",
+            Role::Assistant => "model",
+            // Carried by `system_instruction`, never as a turn.
+            Role::System => continue,
+        };
+
+        let parts = match &msg.content {
+            MessageContent::Text(text) => vec![VertexPart::Text { text: text.clone() }],
+            MessageContent::Blocks(blocks) => {
+                let mut parts = Vec::new();
+                for block in blocks {
+                    match block {
+                        ContentBlock::Text { text, .. } => {
+                            parts.push(VertexPart::Text { text: text.clone() });
+                        }
+                        ContentBlock::ToolUse { name, input, .. } => {
+                            parts.push(VertexPart::FunctionCall {
+                                function_call: VertexFunctionCallData {
+                                    name: name.clone(),
+                                    args: input.clone(),
+                                },
+                            });
+                        }
+                        ContentBlock::Image { media_type, data } => {
+                            parts.push(VertexPart::InlineData {
+                                inline_data: VertexInlineData {
+                                    mime_type: media_type.clone(),
+                                    data: data.clone(),
+                                },
+                            });
+                        }
+                        ContentBlock::ToolResult { content, .. } => {
+                            // NOTE(review): `name` is sent empty. A functionResponse is
+                            // presumably paired with its functionCall by name, so this
+                            // likely needs the originating tool's name — confirm which
+                            // ToolResult/ToolUse field carries it and plumb it through.
+                            parts.push(VertexPart::FunctionResponse {
+                                function_response: VertexFunctionResponseData {
+                                    name: String::new(),
+                                    response: serde_json::json!({ "result": content }),
+                                },
+                            });
+                        }
+                        ContentBlock::Thinking { .. } => {}
+                        _ => {}
+                    }
+                }
+                parts
+            }
+        };
+
+        if !parts.is_empty() {
+            contents.push(VertexContent {
+                role: Some(role.to_string()),
+                parts,
+            });
+        }
+    }
+
+    (contents, system_instruction)
+}
+
+/// Pick the system prompt: the explicit `system` value if set, otherwise the
+/// first plain-text system-role message.
+fn extract_system(messages: &[Message], system: &Option<String>) -> Option<VertexContent> {
+    let text = system.clone().or_else(|| {
+        messages.iter().find_map(|m| {
+            if m.role == Role::System {
+                match &m.content {
+                    MessageContent::Text(t) => Some(t.clone()),
+                    _ => None,
+                }
+            } else {
+                None
+            }
+        })
+    })?;
+
+    Some(VertexContent {
+        role: None,
+        parts: vec![VertexPart::Text { text }],
+    })
+}
+
+/// Convert the request's tool definitions into a single Vertex tool entry
+/// (or no entry when the request has no tools).
+fn convert_tools(request: &CompletionRequest) -> Vec<VertexToolConfig> {
+    if request.tools.is_empty() {
+        return Vec::new();
+    }
+
+    let declarations: Vec<VertexFunctionDeclaration> = request
+        .tools
+        .iter()
+        .map(|t| {
+            let normalized =
+                openfang_types::tool::normalize_schema_for_provider(&t.input_schema, "gemini");
+            VertexFunctionDeclaration {
+                name: t.name.clone(),
+                description: t.description.clone(),
+                parameters: normalized,
+            }
+        })
+        .collect();
+
+    vec![VertexToolConfig {
+        function_declarations: declarations,
+    }]
+}
+
+/// Assemble the request body shared by `complete` and `stream`.
+fn build_vertex_request(request: &CompletionRequest) -> VertexRequest {
+    let (contents, system_instruction) = convert_messages(&request.messages, &request.system);
+    VertexRequest {
+        contents,
+        system_instruction,
+        tools: convert_tools(request),
+        generation_config: Some(GenerationConfig {
+            temperature: Some(request.temperature),
+            max_output_tokens: Some(request.max_tokens),
+        }),
+    }
+}
+
+/// Generate a short random id for a tool call (the API does not supply one).
+fn new_tool_call_id() -> String {
+    format!("call_{}", &uuid::Uuid::new_v4().to_string()[..8])
+}
+
+/// Derive the stop reason from the API's `finishReason` and whether the model
+/// requested any tool calls.
+///
+/// Tool calls take precedence: a response that requests a tool call can still
+/// report `finishReason: "STOP"`, and must be surfaced as `ToolUse` so the
+/// caller runs the tools. This matches what the streaming path reports.
+fn resolve_stop_reason(finish_reason: Option<&str>, has_tool_calls: bool) -> StopReason {
+    if has_tool_calls {
+        return StopReason::ToolUse;
+    }
+    match finish_reason {
+        Some("MAX_TOKENS") => StopReason::MaxTokens,
+        // STOP, SAFETY, RECITATION, BLOCKLIST and anything unknown end the turn.
+        _ => StopReason::EndTurn,
+    }
+}
+
+/// Convert a non-streaming response into a `CompletionResponse`.
+///
+/// # Errors
+/// Returns `LlmError::Parse` when the response has no candidates.
+fn convert_response(resp: VertexResponse) -> Result<CompletionResponse, LlmError> {
+    let candidate = resp
+        .candidates
+        .into_iter()
+        .next()
+        .ok_or_else(|| LlmError::Parse("No candidates in Vertex AI response".to_string()))?;
+
+    let mut content = Vec::new();
+    let mut tool_calls = Vec::new();
+
+    if let Some(vertex_content) = candidate.content {
+        for part in vertex_content.parts {
+            match part {
+                VertexPart::Text { text } => {
+                    content.push(ContentBlock::Text {
+                        text,
+                        provider_metadata: None,
+                    });
+                }
+                VertexPart::FunctionCall { function_call } => {
+                    tool_calls.push(ToolCall {
+                        id: new_tool_call_id(),
+                        name: function_call.name,
+                        input: function_call.args,
+                    });
+                }
+                _ => {}
+            }
+        }
+    }
+
+    let stop_reason =
+        resolve_stop_reason(candidate.finish_reason.as_deref(), !tool_calls.is_empty());
+
+    let usage = resp
+        .usage_metadata
+        .map(|u| TokenUsage {
+            input_tokens: u.prompt_token_count,
+            output_tokens: u.candidates_token_count,
+        })
+        .unwrap_or_default();
+
+    Ok(CompletionResponse {
+        content,
+        stop_reason,
+        tool_calls,
+        usage,
+    })
+}
+
+// ── Streaming helpers ──────────────────────────────────────────────────
+
+/// Remove every complete (newline-terminated) line from `pending` and return
+/// them decoded, leaving any unterminated tail in place.
+///
+/// Decoding happens per complete line, not per network chunk, so a multi-byte
+/// UTF-8 character split across two chunks is not corrupted.
+fn drain_complete_lines(pending: &mut Vec<u8>) -> Vec<String> {
+    let mut lines = Vec::new();
+    let mut start = 0;
+    while let Some(offset) = pending[start..].iter().position(|&b| b == b'\n') {
+        let end = start + offset;
+        lines.push(String::from_utf8_lossy(&pending[start..end]).into_owned());
+        start = end + 1;
+    }
+    pending.drain(..start);
+    lines
+}
+
+/// Everything accumulated while consuming one SSE response.
+#[derive(Default)]
+struct StreamState {
+    text: String,
+    tool_calls: Vec<ToolCall>,
+    usage: Option<TokenUsage>,
+    finish_reason: Option<String>,
+}
+
+impl StreamState {
+    /// Consume one SSE line, forwarding text deltas to `tx`.
+    ///
+    /// Returns `true` when the line is the `[DONE]` sentinel. Lines that are
+    /// not `data: ` events are ignored; unparseable events are logged and skipped.
+    async fn ingest_line(
+        &mut self,
+        line: &str,
+        tx: &tokio::sync::mpsc::Sender<StreamEvent>,
+    ) -> bool {
+        let json_str = match line.trim().strip_prefix("data: ") {
+            Some(rest) => rest,
+            None => return false,
+        };
+        if json_str == "[DONE]" {
+            return true;
+        }
+
+        let resp = match serde_json::from_str::<VertexResponse>(json_str) {
+            Ok(resp) => resp,
+            Err(e) => {
+                warn!(error = %e, "Skipping unparseable Vertex AI stream event");
+                return false;
+            }
+        };
+
+        if let Some(candidate) = resp.candidates.into_iter().next() {
+            if let Some(reason) = candidate.finish_reason {
+                self.finish_reason = Some(reason);
+            }
+            if let Some(content) = candidate.content {
+                for part in content.parts {
+                    match part {
+                        VertexPart::Text { text } => {
+                            self.text.push_str(&text);
+                            // A closed receiver is not an error for the driver.
+                            let _ = tx.send(StreamEvent::TextDelta { text }).await;
+                        }
+                        VertexPart::FunctionCall { function_call } => {
+                            self.tool_calls.push(ToolCall {
+                                id: new_tool_call_id(),
+                                name: function_call.name,
+                                input: function_call.args,
+                            });
+                        }
+                        _ => {}
+                    }
+                }
+            }
+        }
+        if let Some(usage) = resp.usage_metadata {
+            self.usage = Some(TokenUsage {
+                input_tokens: usage.prompt_token_count,
+                output_tokens: usage.candidates_token_count,
+            });
+        }
+        false
+    }
+}
+
+impl VertexAIDriver {
+    /// POST `body` to `url` with a bearer token and return the successful response.
+    ///
+    /// - 429/503 are retried up to three times with linear backoff (2s, 4s, 6s).
+    /// - A single 401 clears the token cache and retries once with a fresh token:
+    ///   the cache lifetime is counted from when the token was fetched, so a
+    ///   token that was already partly used up can expire while still cached.
+    ///
+    /// # Errors
+    /// `LlmError::Http` on transport failure, `RateLimited`/`Overloaded` when
+    /// retries are exhausted, `LlmError::Api` for any other non-success status.
+    async fn post_with_retry(
+        &self,
+        url: &str,
+        body: &VertexRequest,
+        streaming: bool,
+    ) -> Result<reqwest::Response, LlmError> {
+        const MAX_RETRIES: u64 = 3;
+
+        let mut access_token = self.get_access_token().await?;
+        let mut attempt: u64 = 0;
+        let mut reauthenticated = false;
+
+        loop {
+            debug!(url = %url, attempt, streaming, "Sending Vertex AI request");
+
+            let resp = self
+                .client
+                .post(url)
+                .header("Authorization", format!("Bearer {}", access_token))
+                .header("Content-Type", "application/json")
+                .json(body)
+                .send()
+                .await
+                .map_err(|e| LlmError::Http(e.to_string()))?;
+
+            let status = resp.status().as_u16();
+
+            if status == 401 && !reauthenticated {
+                reauthenticated = true;
+                warn!("Vertex AI rejected the access token, refreshing once");
+                {
+                    let mut cache = self.token_cache.write().await;
+                    cache.token = None;
+                    cache.expires_at = None;
+                }
+                access_token = self.get_access_token().await?;
+                continue;
+            }
+
+            if status == 429 || status == 503 {
+                if attempt < MAX_RETRIES {
+                    let retry_ms = (attempt + 1) * 2000;
+                    warn!(status, retry_ms, streaming, "Rate limited/overloaded, retrying");
+                    tokio::time::sleep(Duration::from_millis(retry_ms)).await;
+                    attempt += 1;
+                    continue;
+                }
+                return Err(if status == 429 {
+                    LlmError::RateLimited {
+                        retry_after_ms: 5000,
+                    }
+                } else {
+                    LlmError::Overloaded {
+                        retry_after_ms: 5000,
+                    }
+                });
+            }
+
+            if !resp.status().is_success() {
+                let body = resp.text().await.unwrap_or_default();
+                // Prefer the structured error message; fall back to the raw body.
+                let message = serde_json::from_str::<VertexErrorResponse>(&body)
+                    .map(|e| e.error.message)
+                    .unwrap_or(body);
+                return Err(LlmError::Api { status, message });
+            }
+
+            return Ok(resp);
+        }
+    }
+}
+
+// ── LlmDriver implementation ──────────────────────────────────────────
+
+#[async_trait]
+impl LlmDriver for VertexAIDriver {
+    /// Run a non-streaming completion.
+    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
+        let vertex_request = build_vertex_request(&request);
+        let url = self.build_endpoint(&request.model, false);
+
+        let resp = self.post_with_retry(&url, &vertex_request, false).await?;
+
+        let body = resp
+            .text()
+            .await
+            .map_err(|e| LlmError::Http(e.to_string()))?;
+        let vertex_response: VertexResponse =
+            serde_json::from_str(&body).map_err(|e| LlmError::Parse(e.to_string()))?;
+
+        convert_response(vertex_response)
+    }
+
+    /// Run a streaming completion, forwarding text deltas to `tx` and returning
+    /// the assembled response once the stream ends.
+    async fn stream(
+        &self,
+        request: CompletionRequest,
+        tx: tokio::sync::mpsc::Sender<StreamEvent>,
+    ) -> Result<CompletionResponse, LlmError> {
+        let vertex_request = build_vertex_request(&request);
+        let url = format!("{}?alt=sse", self.build_endpoint(&request.model, true));
+
+        let resp = self.post_with_retry(&url, &vertex_request, true).await?;
+
+        // Process SSE stream
+        let mut byte_stream = resp.bytes_stream();
+        let mut pending: Vec<u8> = Vec::new();
+        let mut state = StreamState::default();
+        let mut done = false;
+
+        'read: while let Some(chunk_result) = byte_stream.next().await {
+            let chunk = chunk_result.map_err(|e| LlmError::Http(e.to_string()))?;
+            pending.extend_from_slice(&chunk);
+
+            for line in drain_complete_lines(&mut pending) {
+                if state.ingest_line(&line, &tx).await {
+                    done = true;
+                    break 'read;
+                }
+            }
+        }
+
+        // The final event may arrive without a trailing newline.
+        if !done && !pending.is_empty() {
+            let line = String::from_utf8_lossy(&pending).into_owned();
+            state.ingest_line(&line, &tx).await;
+        }
+
+        let stop_reason = resolve_stop_reason(
+            state.finish_reason.as_deref(),
+            !state.tool_calls.is_empty(),
+        );
+        let usage = state.usage.unwrap_or_default();
+
+        let _ = tx
+            .send(StreamEvent::ContentComplete { stop_reason, usage })
+            .await;
+
+        let content = if state.text.is_empty() {
+            Vec::new()
+        } else {
+            vec![ContentBlock::Text {
+                text: state.text,
+                provider_metadata: None,
+            }]
+        };
+
+        Ok(CompletionResponse {
+            content,
+            stop_reason,
+            tool_calls: state.tool_calls,
+            usage,
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_stop_reason_prefers_tool_use() {
+        assert!(matches!(
+            resolve_stop_reason(Some("STOP"), true),
+            StopReason::ToolUse
+        ));
+        assert!(matches!(
+            resolve_stop_reason(Some("MAX_TOKENS"), false),
+            StopReason::MaxTokens
+        ));
+        assert!(matches!(
+            resolve_stop_reason(Some("STOP"), false),
+            StopReason::EndTurn
+        ));
+        assert!(matches!(resolve_stop_reason(None, false), StopReason::EndTurn));
+    }
+
+    #[test]
+    fn test_drain_complete_lines_keeps_split_utf8_intact() {
+        // "é" is 0xC3 0xA9; deliver it split across two chunks.
+        let mut pending = b"data: caf\xC3".to_vec();
+        assert!(drain_complete_lines(&mut pending).is_empty());
+        pending.extend_from_slice(b"\xA9\ndata: tail");
+        assert_eq!(drain_complete_lines(&mut pending), vec!["data: café".to_string()]);
+        assert_eq!(pending, b"data: tail".to_vec());
+    }
+
+    #[test]
+    fn test_vertex_driver_creation() {
+        let driver = VertexAIDriver::new("test-project".to_string(), "us-central1".to_string());
+        assert_eq!(driver.project_id, "test-project");
+        assert_eq!(driver.region, "us-central1");
+    }
+
+    #[test]
+    fn test_build_endpoint_non_streaming() {
+        let driver = VertexAIDriver::new("my-project".to_string(), "us-central1".to_string());
+        let endpoint = driver.build_endpoint("gemini-2.0-flash", false);
+        assert_eq!(
+            endpoint,
+            "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent"
+        );
+    }
+
+    #[test]
+    fn test_build_endpoint_streaming() {
+        let driver = VertexAIDriver::new("my-project".to_string(), "europe-west4".to_string());
+        let endpoint = driver.build_endpoint("gemini-1.5-pro", true);
+        assert_eq!(
+            endpoint,
+
"https://europe-west4-aiplatform.googleapis.com/v1/projects/my-project/locations/europe-west4/publishers/google/models/gemini-1.5-pro:streamGenerateContent" + ); + } + + #[test] + fn test_build_endpoint_strips_model_prefix() { + let driver = VertexAIDriver::new("my-project".to_string(), "us-central1".to_string()); + let endpoint = driver.build_endpoint("models/gemini-2.0-flash", false); + assert_eq!( + endpoint, + "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent" + ); + } + + #[test] + fn test_token_cache_initially_invalid() { + let cache = TokenCache::new(); + assert!(!cache.is_valid()); + assert!(cache.token.is_none()); + } + + #[test] + fn test_vertex_content_serialization() { + let content = VertexContent { + role: Some("user".to_string()), + parts: vec![VertexPart::Text { + text: "Hello".to_string(), + }], + }; + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"text\":\"Hello\"")); + } +} diff --git a/docs/VERTEX_AI_LOCAL_TESTING.md b/docs/VERTEX_AI_LOCAL_TESTING.md new file mode 100644 index 00000000..12463628 --- /dev/null +++ b/docs/VERTEX_AI_LOCAL_TESTING.md @@ -0,0 +1,171 @@ +# Vertex AI Local Testing Guide + +## Prerequisites + +1. **GCP Service Account JSON** at `C:\Users\at384\Downloads\osc\dbg-grcit-dev-e1-c79e5571a5a7.json` +2. **gcloud CLI** installed and in PATH +3. **Rust toolchain** with cargo + +## Quick Start (Recommended) + +### Option 1: Use the Batch File + +```batch +# Run this from the openfang directory: +start-vertex.bat +``` + +This automatically: +- Clears proxy settings +- Sets `GOOGLE_APPLICATION_CREDENTIALS` +- Pre-fetches OAuth token via `gcloud auth print-access-token` +- Sets `VERTEX_AI_ACCESS_TOKEN` env var +- Starts OpenFang + +### Option 2: Manual PowerShell Setup + +```powershell +# 1. 
Kill any existing instances +taskkill /F /IM openfang.exe 2>$null + +# 2. Set environment variables (CRITICAL: clear proxy!) +$env:HTTPS_PROXY = "" +$env:HTTP_PROXY = "" +$env:GOOGLE_APPLICATION_CREDENTIALS = "C:\Users\at384\Downloads\osc\dbg-grcit-dev-e1-c79e5571a5a7.json" + +# 3. Pre-fetch OAuth token (IMPORTANT: avoids subprocess issues on Windows) +$env:VERTEX_AI_ACCESS_TOKEN = gcloud auth print-access-token + +# 4. Start OpenFang +cd C:\Users\at384\Downloads\osc\dllm\openfang +.\target\debug\openfang.exe start +``` + +## Testing the API + +### Create an Agent + +```powershell +$env:HTTPS_PROXY = "" +$env:HTTP_PROXY = "" + +# Spawn agent with default Vertex AI provider (from config.toml) +$body = '{"manifest_toml":"name = \"test-agent\"\nmode = \"assistant\""}' +Invoke-RestMethod -Uri "http://127.0.0.1:50051/api/agents" -Method POST -ContentType "application/json" -Body $body +``` + +### Send Chat Request + +```powershell +$env:HTTPS_PROXY = "" +$env:HTTP_PROXY = "" + +$body = '{"model":"test-agent","messages":[{"role":"user","content":"What is 2+2?"}]}' +$response = Invoke-RestMethod -Uri "http://127.0.0.1:50051/v1/chat/completions" -Method POST -ContentType "application/json" -Body $body -TimeoutSec 120 +Write-Host $response.choices[0].message.content +``` + +### Direct Vertex AI Test (Bypass OpenFang) + +```powershell +$env:HTTPS_PROXY = "" +$env:HTTP_PROXY = "" + +$token = gcloud auth print-access-token +$project = "dbg-grcit-dev-e1" +$region = "us-central1" +$model = "gemini-2.0-flash" +$url = "https://$region-aiplatform.googleapis.com/v1/projects/$project/locations/$region/publishers/google/models/$($model):generateContent" + +$body = @{contents = @(@{role = "user"; parts = @(@{text = "Hello!"})})} | ConvertTo-Json -Depth 5 +Invoke-RestMethod -Uri $url -Method POST -Headers @{Authorization = "Bearer $token"} -ContentType "application/json" -Body $body +``` + +## Configuration + +### ~/.openfang/config.toml + +```toml +[default_model] +provider = 
"vertex-ai"
+model = "gemini-2.0-flash"
+
+[memory]
+decay_rate = 0.05
+
+[network]
+listen_addr = "127.0.0.1:4200"
+```
+
+## Environment Variables
+
+| Variable | Purpose | Required |
+|----------|---------|----------|
+| `GOOGLE_APPLICATION_CREDENTIALS` | Path to service account JSON (also the fallback source of the project ID) | Yes, unless a project ID variable is set and a token is supplied another way |
+| `VERTEX_AI_ACCESS_TOKEN` | Pre-fetched OAuth token (bypasses gcloud subprocess) | Recommended on Windows |
+| `GOOGLE_CLOUD_PROJECT` / `GCLOUD_PROJECT` / `GCP_PROJECT` | Override project ID (checked in this order) | No (auto-detected from JSON) |
+| `GOOGLE_CLOUD_REGION` / `VERTEX_AI_REGION` | Override region | No (defaults to us-central1) |
+| `HTTPS_PROXY` / `HTTP_PROXY` | **MUST be empty** for local testing | Critical |
+
+## Troubleshooting
+
+### "Agent processing failed" (500 Error)
+
+**Cause:** The driver could not obtain a token from the gcloud subprocess (common on Windows).
+
+**Solution:** Pre-fetch the token:
+```powershell
+$env:VERTEX_AI_ACCESS_TOKEN = gcloud auth print-access-token
+```
+
+### "Connection refused"
+
+**Cause:** OpenFang not running or wrong port.
+
+**Solution:** Ensure server is running on port 50051:
+```powershell
+Get-NetTCPConnection -LocalPort 50051 -ErrorAction SilentlyContinue
+```
+
+### Token Expired
+
+**Cause:** OAuth tokens expire after ~1 hour.
+
+**Solution:** Re-fetch the token, then restart OpenFang from the same shell (a running process does not see the new value):
+```powershell
+$env:VERTEX_AI_ACCESS_TOKEN = gcloud auth print-access-token
+```
+
+## Build Commands
+
+```powershell
+cd C:\Users\at384\Downloads\osc\dllm\openfang
+$env:PATH = "$env:USERPROFILE\.cargo\bin;$env:PATH"
+
+# Debug build (faster compilation)
+cargo build -p openfang-cli
+
+# Run tests
+cargo test -p openfang-runtime --lib vertex
+
+# Check formatting
+cargo fmt --check -p openfang-runtime
+
+# Run clippy
+cargo clippy -p openfang-runtime --lib -- -W warnings
+```
+
+## API Endpoints
+
+| Endpoint | Method | Purpose |
+|----------|--------|---------|
+| `http://127.0.0.1:50051/api/agents` | GET | List agents |
+| `http://127.0.0.1:50051/api/agents` | POST | Create agent |
+| `http://127.0.0.1:50051/api/agents/{id}` | DELETE | Delete agent |
+| `http://127.0.0.1:50051/v1/chat/completions` | POST | OpenAI-compatible chat |
+| `http://127.0.0.1:50051/` | GET | Dashboard UI |
+
+## Files Modified in PR
+
+- `crates/openfang-runtime/src/drivers/vertex.rs` (NEW - ~790 lines)
+- `crates/openfang-runtime/src/drivers/mod.rs` (+40 lines)
diff --git a/start-vertex.bat b/start-vertex.bat
new file mode 100644
index 00000000..0eeb1abf
--- /dev/null
+++ b/start-vertex.bat
@@ -0,0 +1,12 @@
+@echo off
+set HTTPS_PROXY=
+set HTTP_PROXY=
+set GOOGLE_APPLICATION_CREDENTIALS=C:\Users\at384\Downloads\osc\dbg-grcit-dev-e1-c79e5571a5a7.json
+set RUST_LOG=openfang_runtime::drivers::vertex=debug,openfang=info
+set RUST_BACKTRACE=full
+cd /d C:\Users\at384\Downloads\osc\dllm\openfang
+echo Getting access token...
+for /f "tokens=*" %%a in ('gcloud auth print-access-token') do set VERTEX_AI_ACCESS_TOKEN=%%a
+echo Token set, starting OpenFang...
+target\debug\openfang.exe start
+pause
diff --git a/test_vertex_e2e.py b/test_vertex_e2e.py
new file mode 100644
index 00000000..06ce8ae0
--- /dev/null
+++ b/test_vertex_e2e.py
@@ -0,0 +1,175 @@
+"""
+End-to-end test for Vertex AI driver.
+Tests that the Vertex AI provider works with service account authentication.
+"""
+import json
+import os
+import ssl
+import subprocess
+import sys
+import urllib.error
+import urllib.request
+
+# Service account path: GOOGLE_APPLICATION_CREDENTIALS wins when set and
+# non-empty; otherwise fall back to the original developer-machine location.
+SA_PATH = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS") or (
+    r"C:\Users\at384\Downloads\osc\dbg-grcit-dev-e1-c79e5571a5a7.json"
+)
+
+SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
+REGION = "us-central1"
+MODEL = "gemini-2.0-flash"
+
+
+def _import_google_auth(install_if_missing=False):
+    """Return ``(service_account, Request)`` from google-auth, or ``None``.
+
+    When ``install_if_missing`` is true a failed import triggers one
+    ``pip install`` attempt. ``requests`` is installed too because
+    ``google.auth.transport.requests`` cannot be imported without it.
+    """
+    try:
+        from google.oauth2 import service_account
+        from google.auth.transport.requests import Request
+    except ImportError:
+        if not install_if_missing:
+            print("❌ google-auth is not installed (pip install google-auth requests)")
+            return None
+        print("Installing google-auth...")
+        try:
+            subprocess.run(
+                [sys.executable, "-m", "pip", "install", "google-auth", "requests", "-q"],
+                check=True,
+            )
+            from google.oauth2 import service_account
+            from google.auth.transport.requests import Request
+        except (subprocess.CalledProcessError, ImportError) as e:
+            print(f"❌ Could not install google-auth: {e}")
+            return None
+    return service_account, Request
+
+
+def _fetch_token(auth_modules):
+    """Return an OAuth access token for the service account at ``SA_PATH``."""
+    service_account, Request = auth_modules
+    credentials = service_account.Credentials.from_service_account_file(
+        SA_PATH,
+        scopes=SCOPES,
+    )
+    credentials.refresh(Request())
+    return credentials.token
+
+
+def _read_service_account():
+    """Load the service account JSON from ``SA_PATH``."""
+    with open(SA_PATH, encoding="utf-8") as f:
+        return json.load(f)
+
+
+def _build_request(project_id, method, token, payload):
+    """Build the POST request for ``method`` on the test model."""
+    url = (
+        f"https://{REGION}-aiplatform.googleapis.com/v1/projects/{project_id}"
+        f"/locations/{REGION}/publishers/google/models/{MODEL}:{method}"
+    )
+    headers = {
+        "Authorization": f"Bearer {token}",
+        "Content-Type": "application/json"
+    }
+    return urllib.request.Request(
+        url,
+        data=json.dumps(payload).encode(),
+        headers=headers,
+        method="POST"
+    )
+
+
+def test_vertex_ai():
+    """Call generateContent once and print the reply. Returns True on success."""
+    auth_modules = _import_google_auth(install_if_missing=True)
+    if auth_modules is None:
+        return False
+
+    # Read project ID from service account
+    sa = _read_service_account()
+    project_id = sa.get("project_id")
+    print(f"Project ID: {project_id}")
+    print(f"Service Account: {sa.get('client_email')}")
+
+    # Get OAuth token using service account
+    print("\n=== Getting OAuth Token ===")
+    token = _fetch_token(auth_modules)
+    # Never echo any part of the bearer token: test output ends up in logs.
+    print(f"✅ Token obtained ({len(token)} characters)")
+
+    # Test Vertex AI API
+    print("\n=== Testing Vertex AI API ===")
+    payload = {
+        "contents": [{
+            "role": "user",
+            "parts": [{"text": "Say 'Hello from Vertex AI!' exactly, nothing else."}]
+        }],
+        "generationConfig": {
+            "maxOutputTokens": 50
+        }
+    }
+    req = _build_request(project_id, "generateContent", token, payload)
+
+    try:
+        ctx = ssl.create_default_context()
+        with urllib.request.urlopen(req, context=ctx, timeout=30) as resp:
+            response = json.loads(resp.read().decode())
+            text = response["candidates"][0]["content"]["parts"][0]["text"]
+            print(f"✅ Vertex AI Response: {text}")
+
+            # Check usage
+            if "usageMetadata" in response:
+                usage = response["usageMetadata"]
+                print(f" Input tokens: {usage.get('promptTokenCount', 'N/A')}")
+                print(f" Output tokens: {usage.get('candidatesTokenCount', 'N/A')}")
+
+            return True
+    except urllib.error.HTTPError as e:
+        print(f"❌ HTTP Error {e.code}: {e.reason}")
+        print(f" Response: {e.read().decode()}")
+        return False
+    except Exception as e:
+        print(f"❌ API call failed: {e}")
+        return False
+
+
+def test_streaming():
+    """Test streaming endpoint."""
+    auth_modules = _import_google_auth()
+    if auth_modules is None:
+        return False
+
+    project_id = _read_service_account().get("project_id")
+    token = _fetch_token(auth_modules)
+
+    print("\n=== Testing Streaming API ===")
+    payload = {
+        "contents": [{
+            "role": "user",
+            "parts": [{"text": "Count from 1 to 5, one number per line."}]
+        }],
+        "generationConfig": {
+            "maxOutputTokens": 100
+        }
+    }
+    req = _build_request(project_id, "streamGenerateContent?alt=sse", token, payload)
+
+    try:
+        ctx = ssl.create_default_context()
+        with urllib.request.urlopen(req, context=ctx, timeout=30) as resp:
+            print("✅ Streaming response:")
+            full_text = ""
+            for line in resp:
+                line = line.decode().strip()
+                # Only SSE data events carry a JSON payload.
+                if not line.startswith("data: "):
+                    continue
+                data = json.loads(line[6:])
+                for candidate in data.get("candidates", []):
+                    for part in candidate.get("content", {}).get("parts", []):
+                        if "text" in part:
+                            full_text += part["text"]
+                            print(f" chunk: {part['text']!r}")
+            print(f" Full text: {full_text}")
+            return True
+    except Exception as e:
+        print(f"❌ Streaming failed: {e}")
+        return False
+
+
+if __name__ == "__main__":
+    print("="*60)
+    print("VERTEX AI END-TO-END TEST")
+    print("="*60)
+
+    success1 = test_vertex_ai()
+    success2 = test_streaming()
+
+    print("\n" + "="*60)
+    if success1 and success2:
+        print("✅ ALL TESTS PASSED")
+    else:
+        print("❌ SOME TESTS FAILED")
+    print("="*60)
+
+    sys.exit(0 if (success1 and success2) else 1)