[Feat] LiteLLM RAG API - Add support for Vertex RAG engine (#17117)

* add VertexAIVectorStoreOptions

* Revert "add VertexAIVectorStoreOptions"

This reverts commit b086adf10b.

* add VertexAIVectorStoreOptions

* add get_rag_ingestion_class

* add VertexAIRAGTransformation

* test ingestion

* docs vertex ai rag engine
This commit is contained in:
Ishaan Jaff
2025-11-26 15:49:04 -08:00
committed by GitHub
parent 32617d1e72
commit 379655e16b
8 changed files with 779 additions and 34 deletions

View File

@@ -4,9 +4,8 @@ All-in-one document ingestion pipeline: **Upload → Chunk → Embed → Vector
| Feature | Supported |
|---------|-----------|
| Cost Tracking | ❌ |
| Logging | ✅ |
| Supported Providers | `openai`, `bedrock`, `gemini` |
| Supported Providers | `openai`, `bedrock`, `vertex_ai`, `gemini` |
## Quick Start
@@ -50,9 +49,9 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
}"
```
### Gemini
### Vertex AI RAG Engine
```bash showLineNumbers title="Ingest to Gemini File Search"
```bash showLineNumbers title="Ingest to Vertex AI RAG Corpus"
curl -X POST "http://localhost:4000/v1/rag/ingest" \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
@@ -64,38 +63,14 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
},
\"ingest_options\": {
\"vector_store\": {
\"custom_llm_provider\": \"gemini\"
\"custom_llm_provider\": \"vertex_ai\",
\"vector_store_id\": \"your-corpus-id\",
\"gcs_bucket\": \"your-gcs-bucket\"
}
}
}"
```
**With Custom Chunking:**
```bash showLineNumbers title="Ingest with custom chunking"
curl -X POST "http://localhost:4000/v1/rag/ingest" \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d '{
"file": {
"filename": "document.txt",
"content": "'$(base64 -i document.txt)'",
"content_type": "text/plain"
},
"ingest_options": {
"vector_store": {
"custom_llm_provider": "gemini"
},
"chunking_strategy": {
"white_space_config": {
"max_tokens_per_chunk": 200,
"max_overlap_tokens": 20
}
}
}
}'
```
## Response
```json
@@ -242,6 +217,26 @@ When `vector_store_id` is omitted, LiteLLM automatically creates:
- Data Source
:::
### vector_store (Vertex AI)
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `custom_llm_provider` | string | - | `"vertex_ai"` |
| `vector_store_id` | string | **required** | RAG corpus ID |
| `gcs_bucket` | string | **required** | GCS bucket for file uploads |
| `vertex_project` | string | env `VERTEXAI_PROJECT` | GCP project ID |
| `vertex_location` | string | `us-central1` | GCP region |
| `vertex_credentials` | string | ADC | Path to credentials JSON |
| `wait_for_import` | boolean | `true` | Wait for import to complete |
| `import_timeout` | integer | `600` | Timeout in seconds (if waiting) |
:::info Vertex AI Prerequisites
1. Create a RAG corpus in Vertex AI console or via API
2. Create a GCS bucket for file uploads
3. Authenticate via `gcloud auth application-default login`
4. Install: `pip install 'google-cloud-aiplatform>=1.60.0'`
:::
## Input Examples
### File (Base64)
@@ -271,3 +266,40 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
}'
```
## Chunking Strategy
Control how documents are split into chunks before embedding. Specify `chunking_strategy` in `ingest_options`.
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `chunk_size` | integer | `1000` | Maximum size of each chunk |
| `chunk_overlap` | integer | `200` | Overlap between consecutive chunks |
### Vertex AI RAG Engine
Vertex AI RAG Engine supports custom chunking via the `chunking_strategy` parameter. Chunks are processed server-side during import.
```bash showLineNumbers title="Vertex AI with custom chunking"
curl -X POST "http://localhost:4000/v1/rag/ingest" \
-H "Authorization: Bearer sk-1234" \
-H "Content-Type: application/json" \
-d "{
\"file\": {
\"filename\": \"document.txt\",
\"content\": \"$(base64 -i document.txt)\",
\"content_type\": \"text/plain\"
},
\"ingest_options\": {
\"chunking_strategy\": {
\"chunk_size\": 500,
\"chunk_overlap\": 100
},
\"vector_store\": {
\"custom_llm_provider\": \"vertex_ai\",
\"vector_store_id\": \"your-corpus-id\",
\"gcs_bucket\": \"your-gcs-bucket\"
}
}
}"
```

View File

@@ -0,0 +1,14 @@
"""
Vertex AI RAG Engine module.
Handles RAG ingestion via Vertex AI RAG Engine API.
"""
from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
__all__ = [
"VertexAIRAGIngestion",
"VertexAIRAGTransformation",
]

View File

@@ -0,0 +1,320 @@
"""
Vertex AI RAG Engine Ingestion implementation.
Uses:
- litellm.files.acreate_file for uploading files to GCS
- Vertex AI RAG Engine REST API for importing files into corpus (via httpx)
Key differences from OpenAI:
- Files must be uploaded to GCS first (via litellm.files.acreate_file)
- Embedding is handled internally using text-embedding-005 by default
- Chunking configured via unified chunking_strategy in ingest_options
"""
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm import get_secret_str
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.types.llms.custom_http import httpxSpecialProvider
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGChunkingStrategy, RAGIngestOptions
def _get_str_or_none(value: Any) -> Optional[str]:
"""Cast config value to Optional[str]."""
return str(value) if value is not None else None
def _get_int(value: Any, default: int) -> int:
"""Cast config value to int with default."""
if value is None:
return default
return int(value)
class VertexAIRAGIngestion(BaseRAGIngestion):
"""
Vertex AI RAG Engine ingestion.
Uses litellm.files.acreate_file for GCS upload, then imports into RAG corpus.
Required config in vector_store:
- vector_store_id: RAG corpus ID (required)
Optional config in vector_store:
- vertex_project: GCP project ID (uses env VERTEXAI_PROJECT if not set)
- vertex_location: GCP region (default: us-central1)
- vertex_credentials: Path to credentials JSON (uses ADC if not set)
- wait_for_import: Wait for import to complete (default: True)
- import_timeout: Timeout in seconds (default: 600)
Chunking is configured via ingest_options["chunking_strategy"]:
- chunk_size: Maximum size of chunks (default: 1000)
- chunk_overlap: Overlap between chunks (default: 200)
Authentication:
- Uses Application Default Credentials (ADC)
- Run: gcloud auth application-default login
"""
def __init__(
self,
ingest_options: "RAGIngestOptions",
router: Optional["Router"] = None,
):
super().__init__(ingest_options=ingest_options, router=router)
# Get corpus ID (required for Vertex AI)
self.corpus_id = self.vector_store_config.get("vector_store_id")
if not self.corpus_id:
raise ValueError(
"vector_store_id (corpus ID) is required for Vertex AI RAG ingestion. "
"Please provide an existing RAG corpus ID."
)
# GCP config
self.vertex_project = (
self.vector_store_config.get("vertex_project")
or get_secret_str("VERTEXAI_PROJECT")
)
self.vertex_location = (
self.vector_store_config.get("vertex_location")
or get_secret_str("VERTEXAI_LOCATION")
or "us-central1"
)
self.vertex_credentials = self.vector_store_config.get("vertex_credentials")
# GCS bucket for file uploads
self.gcs_bucket = (
self.vector_store_config.get("gcs_bucket")
or os.environ.get("GCS_BUCKET_NAME")
)
if not self.gcs_bucket:
raise ValueError(
"gcs_bucket is required for Vertex AI RAG ingestion. "
"Set via vector_store config or GCS_BUCKET_NAME env var."
)
# Import settings
self.wait_for_import = self.vector_store_config.get("wait_for_import", True)
self.import_timeout = _get_int(
self.vector_store_config.get("import_timeout"), 600
)
# Validate required config
if not self.vertex_project:
raise ValueError(
"vertex_project is required for Vertex AI RAG ingestion. "
"Set via vector_store config or VERTEXAI_PROJECT env var."
)
def _get_corpus_name(self) -> str:
"""Get full corpus resource name."""
return f"projects/{self.vertex_project}/locations/{self.vertex_location}/ragCorpora/{self.corpus_id}"
async def _upload_file_to_gcs(
self,
file_content: bytes,
filename: str,
content_type: str,
) -> str:
"""
Upload file to GCS using litellm.files.acreate_file.
Returns:
GCS URI of the uploaded file (gs://bucket/path/file)
"""
import litellm
# Set GCS_BUCKET_NAME env var for litellm.files.create_file
# The handler uses this to determine where to upload
original_bucket = os.environ.get("GCS_BUCKET_NAME")
if self.gcs_bucket:
os.environ["GCS_BUCKET_NAME"] = self.gcs_bucket
try:
# Create file tuple for litellm.files.acreate_file
file_tuple = (filename, file_content, content_type)
verbose_logger.debug(
f"Uploading file to GCS via litellm.files.acreate_file: {filename} "
f"(bucket: {self.gcs_bucket})"
)
# Upload to GCS using LiteLLM's file upload
response = await litellm.acreate_file(
file=file_tuple,
purpose="assistants", # Purpose for file storage
custom_llm_provider="vertex_ai",
vertex_project=self.vertex_project,
vertex_location=self.vertex_location,
vertex_credentials=self.vertex_credentials,
)
# The response.id should be the GCS URI
gcs_uri = response.id
verbose_logger.info(f"Uploaded file to GCS: {gcs_uri}")
return gcs_uri
finally:
# Restore original env var
if original_bucket is not None:
os.environ["GCS_BUCKET_NAME"] = original_bucket
elif "GCS_BUCKET_NAME" in os.environ:
del os.environ["GCS_BUCKET_NAME"]
async def _import_file_to_corpus_via_sdk(
self,
gcs_uri: str,
) -> None:
"""
Import file into RAG corpus using the Vertex AI SDK.
The REST API endpoint for importRagFiles is not publicly available,
so we use the Python SDK.
"""
try:
from vertexai import init as vertexai_init
from vertexai import rag # type: ignore[import-not-found]
except ImportError:
raise ImportError(
"vertexai.rag module not found. Vertex AI RAG requires "
"google-cloud-aiplatform>=1.60.0. Install with: "
"pip install 'google-cloud-aiplatform>=1.60.0'"
)
# Initialize Vertex AI
vertexai_init(project=self.vertex_project, location=self.vertex_location)
# Get chunking config from ingest_options (unified interface)
transformation_config = self._build_transformation_config()
corpus_name = self._get_corpus_name()
verbose_logger.debug(f"Importing {gcs_uri} into corpus {self.corpus_id}")
if self.wait_for_import:
# Synchronous import - wait for completion
response = rag.import_files(
corpus_name=corpus_name,
paths=[gcs_uri],
transformation_config=transformation_config,
timeout=self.import_timeout,
)
verbose_logger.info(
f"Import complete: {response.imported_rag_files_count} files imported"
)
else:
# Async import - don't wait
_ = rag.import_files_async(
corpus_name=corpus_name,
paths=[gcs_uri],
transformation_config=transformation_config,
)
verbose_logger.info("Import started asynchronously")
def _build_transformation_config(self) -> Any:
"""
Build Vertex AI TransformationConfig from unified chunking_strategy.
Uses chunking_strategy from ingest_options (not vector_store).
"""
try:
from vertexai import rag # type: ignore[import-not-found]
except ImportError:
raise ImportError(
"vertexai.rag module not found. Vertex AI RAG requires "
"google-cloud-aiplatform>=1.60.0. Install with: "
"pip install 'google-cloud-aiplatform>=1.60.0'"
)
# Get chunking config from ingest_options using transformation class
from typing import cast
from litellm.types.rag import RAGChunkingStrategy
transformation = VertexAIRAGTransformation()
chunking_config = transformation.transform_chunking_strategy_to_vertex_format(
cast(Optional[RAGChunkingStrategy], self.chunking_strategy)
)
chunk_size = chunking_config["chunking_config"]["chunk_size"]
chunk_overlap = chunking_config["chunking_config"]["chunk_overlap"]
return rag.TransformationConfig(
chunking_config=rag.ChunkingConfig(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
),
)
async def embed(
self,
chunks: List[str],
) -> Optional[List[List[float]]]:
"""
Vertex AI handles embedding internally - skip this step.
Returns:
None (Vertex AI embeds when files are imported)
"""
return None
async def store(
self,
file_content: Optional[bytes],
filename: Optional[str],
content_type: Optional[str],
chunks: List[str],
embeddings: Optional[List[List[float]]],
) -> Tuple[Optional[str], Optional[str]]:
"""
Store content in Vertex AI RAG corpus.
Vertex AI workflow:
1. Upload file to GCS via litellm.files.acreate_file
2. Import file into RAG corpus via SDK
3. (Optional) Wait for import to complete
Args:
file_content: Raw file bytes
filename: Name of the file
content_type: MIME type
chunks: Ignored - Vertex AI handles chunking
embeddings: Ignored - Vertex AI handles embedding
Returns:
Tuple of (corpus_id, gcs_uri)
"""
if not file_content or not filename:
verbose_logger.warning(
"No file content or filename provided for Vertex AI ingestion"
)
return _get_str_or_none(self.corpus_id), None
# Step 1: Upload file to GCS
gcs_uri = await self._upload_file_to_gcs(
file_content=file_content,
filename=filename,
content_type=content_type or "application/octet-stream",
)
# Step 2: Import file into RAG corpus
try:
await self._import_file_to_corpus_via_sdk(gcs_uri=gcs_uri)
except Exception as e:
verbose_logger.error(f"Failed to import file into RAG corpus: {e}")
raise RuntimeError(f"Failed to import file into RAG corpus: {e}") from e
return str(self.corpus_id), gcs_uri

View File

@@ -0,0 +1,155 @@
"""
Transformation utilities for Vertex AI RAG Engine.
Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
"""
from typing import Any, Dict, Optional, Tuple
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.rag import RAGChunkingStrategy
class VertexAIRAGTransformation(VertexBase):
"""
Transformation class for Vertex AI RAG Engine API.
Handles:
- Converting unified chunking_strategy to Vertex AI format
- Building import request payloads
- Transforming responses
"""
def __init__(self):
super().__init__()
def get_import_rag_files_url(
self,
vertex_project: str,
vertex_location: str,
corpus_id: str,
) -> str:
"""
Get the URL for importing RAG files.
Note: The REST endpoint for importRagFiles may not be publicly available.
Vertex AI RAG Engine primarily uses gRPC-based SDK.
"""
base_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1"
return f"{base_url}/projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{corpus_id}:importRagFiles"
def get_retrieve_contexts_url(
self,
vertex_project: str,
vertex_location: str,
) -> str:
"""Get the URL for retrieving contexts (search)."""
base_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1"
return f"{base_url}/projects/{vertex_project}/locations/{vertex_location}:retrieveContexts"
def transform_chunking_strategy_to_vertex_format(
self,
chunking_strategy: Optional[RAGChunkingStrategy],
) -> Dict[str, Any]:
"""
Transform LiteLLM's unified chunking_strategy to Vertex AI RAG format.
LiteLLM format (RAGChunkingStrategy):
{
"chunk_size": 1000,
"chunk_overlap": 200,
"separators": ["\n\n", "\n", " ", ""]
}
Vertex AI RAG format (TransformationConfig):
{
"chunking_config": {
"chunk_size": 1000,
"chunk_overlap": 200
}
}
Note: Vertex AI doesn't support custom separators in the same way,
so we only transform chunk_size and chunk_overlap.
"""
if not chunking_strategy:
return {
"chunking_config": {
"chunk_size": DEFAULT_CHUNK_SIZE,
"chunk_overlap": DEFAULT_CHUNK_OVERLAP,
}
}
chunk_size = chunking_strategy.get("chunk_size", DEFAULT_CHUNK_SIZE)
chunk_overlap = chunking_strategy.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)
# Log if separators are provided (not supported by Vertex AI)
if chunking_strategy.get("separators"):
verbose_logger.warning(
"Vertex AI RAG Engine does not support custom separators. "
"The 'separators' parameter will be ignored."
)
return {
"chunking_config": {
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
}
}
def build_import_rag_files_request(
self,
gcs_uri: str,
chunking_strategy: Optional[RAGChunkingStrategy] = None,
) -> Dict[str, Any]:
"""
Build the request payload for importing RAG files.
Args:
gcs_uri: GCS URI of the file to import (e.g., gs://bucket/path/file.txt)
chunking_strategy: LiteLLM unified chunking config
Returns:
Request payload dict for importRagFiles API
"""
transformation_config = self.transform_chunking_strategy_to_vertex_format(
chunking_strategy
)
return {
"import_rag_files_config": {
"gcs_source": {
"uris": [gcs_uri]
},
"rag_file_transformation_config": transformation_config,
}
}
def get_auth_headers(
self,
vertex_credentials: Optional[str] = None,
vertex_project: Optional[str] = None,
) -> Dict[str, str]:
"""
Get authentication headers for Vertex AI API calls.
Uses the base class method to get credentials.
"""
credentials = self.get_vertex_ai_credentials(
{"vertex_credentials": vertex_credentials}
)
project = vertex_project or self.get_vertex_ai_project({})
access_token, _ = self._ensure_access_token(
credentials=credentials,
project_id=project,
custom_llm_provider="vertex_ai",
)
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}

View File

@@ -12,7 +12,7 @@ __all__ = ["ingest", "aingest"]
import asyncio
import contextvars
from functools import partial
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
import httpx
@@ -84,7 +84,7 @@ async def _execute_ingest_pipeline(
provider = vector_store_config.get("custom_llm_provider", "openai")
# Get provider-specific ingestion class
ingestion_class = get_ingestion_class(provider)
ingestion_class = get_rag_ingestion_class(provider)
# Create ingestion instance
ingestion = ingestion_class(

65
litellm/rag/utils.py Normal file
View File

@@ -0,0 +1,65 @@
"""
RAG utility functions.
Provides provider configuration utilities similar to ProviderConfigManager.
"""
from typing import TYPE_CHECKING, Optional, Type
if TYPE_CHECKING:
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
def get_rag_ingestion_class(custom_llm_provider: str) -> Type["BaseRAGIngestion"]:
"""
Get the appropriate RAG ingestion class for a provider.
Args:
custom_llm_provider: The LLM provider name (e.g., "openai", "bedrock", "vertex_ai")
Returns:
The ingestion class for the provider
Raises:
ValueError: If the provider is not supported
"""
from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion
provider_map = {
"openai": OpenAIRAGIngestion,
"bedrock": BedrockRAGIngestion,
"vertex_ai": VertexAIRAGIngestion,
}
ingestion_class = provider_map.get(custom_llm_provider)
if ingestion_class is None:
raise ValueError(
f"RAG ingestion not supported for provider: {custom_llm_provider}. "
f"Supported providers: {list(provider_map.keys())}"
)
return ingestion_class
def get_rag_transformation_class(custom_llm_provider: str):
"""
Get the appropriate RAG transformation class for a provider.
Args:
custom_llm_provider: The LLM provider name
Returns:
The transformation class for the provider, or None if not needed
"""
if custom_llm_provider == "vertex_ai":
from litellm.llms.vertex_ai.rag_engine.transformation import (
VertexAIRAGTransformation,
)
return VertexAIRAGTransformation
# OpenAI and Bedrock don't need special transformations
return None

View File

@@ -86,8 +86,37 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
aws_external_id: Optional[str]
class VertexAIVectorStoreOptions(TypedDict, total=False):
"""
Vertex AI RAG Engine configuration.
Example (use existing corpus):
{"custom_llm_provider": "vertex_ai", "vector_store_id": "CORPUS_ID", "gcs_bucket": "my-bucket"}
Requires:
- gcloud auth application-default login (for ADC authentication)
- Files are uploaded to GCS via litellm.files.create_file, then imported into RAG corpus
- GCS bucket must be provided via gcs_bucket or GCS_BUCKET_NAME env var
"""
custom_llm_provider: Literal["vertex_ai"]
vector_store_id: str # RAG corpus ID (required for Vertex AI)
# GCP config
vertex_project: Optional[str] # GCP project ID (uses env VERTEXAI_PROJECT if not set)
vertex_location: Optional[str] # GCP region (default: us-central1)
vertex_credentials: Optional[str] # Path to credentials JSON (uses ADC if not set)
gcs_bucket: Optional[str] # GCS bucket for file uploads (uses env GCS_BUCKET_NAME if not set)
# Import settings
wait_for_import: Optional[bool] # Wait for import to complete (default: True)
import_timeout: Optional[int] # Timeout in seconds (default: 600)
# Union type for vector store options
RAGIngestVectorStoreOptions = Union[OpenAIVectorStoreOptions, BedrockVectorStoreOptions]
RAGIngestVectorStoreOptions = Union[
OpenAIVectorStoreOptions, BedrockVectorStoreOptions, VertexAIVectorStoreOptions
]
class RAGIngestOptions(TypedDict, total=False):

View File

@@ -0,0 +1,130 @@
"""
Vertex AI RAG Engine ingestion tests.
Requires:
- gcloud auth application-default login (for ADC authentication)
Environment variables:
- VERTEX_PROJECT: GCP project ID (required)
- VERTEX_LOCATION: GCP region (optional, defaults to europe-west1)
- VERTEX_CORPUS_ID: Existing RAG corpus ID (required for Vertex AI)
- GCS_BUCKET_NAME: GCS bucket for file uploads (required)
"""
import os
import sys
from typing import Any, Dict, Optional
import pytest
sys.path.insert(0, os.path.abspath("../../.."))
import litellm
from litellm.types.rag import RAGIngestOptions
from tests.vector_store_tests.rag.base_rag_tests import BaseRAGTest
class TestRAGVertexAI(BaseRAGTest):
"""Test RAG Ingest with Vertex AI RAG Engine."""
@pytest.fixture(autouse=True)
def check_env_vars(self):
"""Check required environment variables before each test."""
vertex_project = os.environ.get("VERTEX_PROJECT")
corpus_id = os.environ.get("VERTEX_CORPUS_ID")
gcs_bucket = os.environ.get("GCS_BUCKET_NAME")
if not vertex_project:
pytest.skip("Skipping Vertex AI test: VERTEX_PROJECT required")
if not corpus_id:
pytest.skip("Skipping Vertex AI test: VERTEX_CORPUS_ID required")
if not gcs_bucket:
pytest.skip("Skipping Vertex AI test: GCS_BUCKET_NAME required")
# Check if vertexai is installed
try:
from vertexai import rag
except ImportError:
pytest.skip("Skipping Vertex AI test: google-cloud-aiplatform>=1.60.0 required")
def get_base_ingest_options(self) -> RAGIngestOptions:
"""
Return Vertex AI-specific ingest options.
Chunking is configured via chunking_strategy (unified interface),
not inside vector_store.
"""
corpus_id = os.environ.get("VERTEX_CORPUS_ID")
vertex_project = os.environ.get("VERTEX_PROJECT")
vertex_location = os.environ.get("VERTEX_LOCATION", "europe-west1")
gcs_bucket = os.environ.get("GCS_BUCKET_NAME")
return {
"chunking_strategy": {
"chunk_size": 512,
"chunk_overlap": 100,
},
"vector_store": {
"custom_llm_provider": "vertex_ai",
"vertex_project": vertex_project,
"vertex_location": vertex_location,
"vector_store_id": corpus_id,
"gcs_bucket": gcs_bucket,
"wait_for_import": True,
},
}
async def query_vector_store(
self,
vector_store_id: str,
query: str,
) -> Optional[Dict[str, Any]]:
"""Query Vertex AI RAG corpus."""
try:
from vertexai import init as vertexai_init
from vertexai import rag
except ImportError:
pytest.skip("vertexai required for Vertex AI tests")
vertex_project = os.environ.get("VERTEX_PROJECT")
vertex_location = os.environ.get("VERTEX_LOCATION", "europe-west1")
# Initialize Vertex AI
vertexai_init(project=vertex_project, location=vertex_location)
# Build corpus name
corpus_name = f"projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{vector_store_id}"
# Query the corpus
response = rag.retrieval_query(
rag_resources=[
rag.RagResource(rag_corpus=corpus_name)
],
text=query,
rag_retrieval_config=rag.RagRetrievalConfig(
top_k=5,
),
)
if hasattr(response, 'contexts') and response.contexts.contexts:
# Convert to dict format
results = []
for ctx in response.contexts.contexts:
results.append({
"text": ctx.text,
"score": ctx.score,
"source_uri": ctx.source_uri,
})
# Check if query terms appear in results
for result in results:
if query.lower() in result["text"].lower():
return {"results": results}
# Return results even if exact match not found
return {"results": results}
return None