[Feat] LiteLLM RAG API - Add support for Vertex RAG engine (#17117)
* add VertexAIVectorStoreOptions
* Revert "add VertexAIVectorStoreOptions"
This reverts commit b086adf10b.
* add VertexAIVectorStoreOptions
* add get_rag_ingestion_class
* add VertexAIRAGTransformation
* test ingestion
* docs vertex ai rag engine
@@ -4,9 +4,8 @@ All-in-one document ingestion pipeline: **Upload → Chunk → Embed → Vector
 
 | Feature | Supported |
 |---------|-----------|
 | Cost Tracking | ❌ |
 | Logging | ✅ |
-| Supported Providers | `openai`, `bedrock`, `gemini` |
+| Supported Providers | `openai`, `bedrock`, `vertex_ai`, `gemini` |
 
 ## Quick Start
@@ -50,9 +49,9 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
   }"
 ```
 
-### Gemini
+### Vertex AI RAG Engine
 
-```bash showLineNumbers title="Ingest to Gemini File Search"
+```bash showLineNumbers title="Ingest to Vertex AI RAG Corpus"
 curl -X POST "http://localhost:4000/v1/rag/ingest" \
   -H "Authorization: Bearer sk-1234" \
   -H "Content-Type: application/json" \
@@ -64,38 +63,14 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
     },
     \"ingest_options\": {
       \"vector_store\": {
-        \"custom_llm_provider\": \"gemini\"
+        \"custom_llm_provider\": \"vertex_ai\",
+        \"vector_store_id\": \"your-corpus-id\",
+        \"gcs_bucket\": \"your-gcs-bucket\"
       }
     }
   }"
 ```
 
-**With Custom Chunking:**
-
-```bash showLineNumbers title="Ingest with custom chunking"
-curl -X POST "http://localhost:4000/v1/rag/ingest" \
-  -H "Authorization: Bearer sk-1234" \
-  -H "Content-Type: application/json" \
-  -d '{
-    "file": {
-      "filename": "document.txt",
-      "content": "'$(base64 -i document.txt)'",
-      "content_type": "text/plain"
-    },
-    "ingest_options": {
-      "vector_store": {
-        "custom_llm_provider": "gemini"
-      },
-      "chunking_strategy": {
-        "white_space_config": {
-          "max_tokens_per_chunk": 200,
-          "max_overlap_tokens": 20
-        }
-      }
-    }
-  }'
-```
-
 ## Response
 
 ```json
@@ -242,6 +217,26 @@ When `vector_store_id` is omitted, LiteLLM automatically creates:
 - Data Source
 :::
 
+### vector_store (Vertex AI)
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `custom_llm_provider` | string | - | `"vertex_ai"` |
+| `vector_store_id` | string | **required** | RAG corpus ID |
+| `gcs_bucket` | string | **required** | GCS bucket for file uploads |
+| `vertex_project` | string | env `VERTEXAI_PROJECT` | GCP project ID |
+| `vertex_location` | string | `us-central1` | GCP region |
+| `vertex_credentials` | string | ADC | Path to credentials JSON |
+| `wait_for_import` | boolean | `true` | Wait for import to complete |
+| `import_timeout` | integer | `600` | Timeout in seconds (if waiting) |
+
+:::info Vertex AI Prerequisites
+1. Create a RAG corpus in Vertex AI console or via API
+2. Create a GCS bucket for file uploads
+3. Authenticate via `gcloud auth application-default login`
+4. Install: `pip install 'google-cloud-aiplatform>=1.60.0'`
+:::
+
 ## Input Examples
 
 ### File (Base64)
@@ -271,3 +266,40 @@ curl -X POST "http://localhost:4000/v1/rag/ingest" \
   }'
 ```
 
+## Chunking Strategy
+
+Control how documents are split into chunks before embedding. Specify `chunking_strategy` in `ingest_options`.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `chunk_size` | integer | `1000` | Maximum size of each chunk |
+| `chunk_overlap` | integer | `200` | Overlap between consecutive chunks |
+
+### Vertex AI RAG Engine
+
+Vertex AI RAG Engine supports custom chunking via the `chunking_strategy` parameter. Chunks are processed server-side during import.
+
+```bash showLineNumbers title="Vertex AI with custom chunking"
+curl -X POST "http://localhost:4000/v1/rag/ingest" \
+  -H "Authorization: Bearer sk-1234" \
+  -H "Content-Type: application/json" \
+  -d "{
+    \"file\": {
+      \"filename\": \"document.txt\",
+      \"content\": \"$(base64 -i document.txt)\",
+      \"content_type\": \"text/plain\"
+    },
+    \"ingest_options\": {
+      \"chunking_strategy\": {
+        \"chunk_size\": 500,
+        \"chunk_overlap\": 100
+      },
+      \"vector_store\": {
+        \"custom_llm_provider\": \"vertex_ai\",
+        \"vector_store_id\": \"your-corpus-id\",
+        \"gcs_bucket\": \"your-gcs-bucket\"
+      }
+    }
+  }"
+```
+
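
For reference, here is the same Vertex AI configuration as a Python `ingest_options` dict, a minimal sketch based on the parameter tables above (corpus ID, bucket, and project values are placeholders):

```python
# Hypothetical values - substitute your own corpus, bucket, and project.
ingest_options = {
    "chunking_strategy": {
        "chunk_size": 500,     # maximum size of each chunk (default 1000)
        "chunk_overlap": 100,  # overlap between consecutive chunks (default 200)
    },
    "vector_store": {
        "custom_llm_provider": "vertex_ai",
        "vector_store_id": "your-corpus-id",   # required: RAG corpus ID
        "gcs_bucket": "your-gcs-bucket",       # required: GCS bucket for uploads
        "vertex_project": "your-gcp-project",  # falls back to env VERTEXAI_PROJECT
        "vertex_location": "us-central1",      # default region
        "wait_for_import": True,               # block until the import completes
        "import_timeout": 600,                 # seconds
    },
}
```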

litellm/llms/vertex_ai/rag_engine/__init__.py (new file, 14 lines)
@@ -0,0 +1,14 @@
"""
Vertex AI RAG Engine module.

Handles RAG ingestion via Vertex AI RAG Engine API.
"""

from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation

__all__ = [
    "VertexAIRAGIngestion",
    "VertexAIRAGTransformation",
]

litellm/llms/vertex_ai/rag_engine/ingestion.py (new file, 320 lines)
@@ -0,0 +1,320 @@
"""
Vertex AI RAG Engine Ingestion implementation.

Uses:
- litellm.files.acreate_file for uploading files to GCS
- the Vertex AI RAG Engine SDK for importing files into the corpus
  (the REST importRagFiles endpoint is not publicly available)

Key differences from OpenAI:
- Files must be uploaded to GCS first (via litellm.files.acreate_file)
- Embedding is handled internally using text-embedding-005 by default
- Chunking configured via unified chunking_strategy in ingest_options
"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from litellm import get_secret_str
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.types.llms.custom_http import httpxSpecialProvider

if TYPE_CHECKING:
    from litellm import Router
    from litellm.types.rag import RAGChunkingStrategy, RAGIngestOptions


def _get_str_or_none(value: Any) -> Optional[str]:
    """Cast config value to Optional[str]."""
    return str(value) if value is not None else None


def _get_int(value: Any, default: int) -> int:
    """Cast config value to int with default."""
    if value is None:
        return default
    return int(value)


class VertexAIRAGIngestion(BaseRAGIngestion):
    """
    Vertex AI RAG Engine ingestion.

    Uses litellm.files.acreate_file for GCS upload, then imports into RAG corpus.

    Required config in vector_store:
    - vector_store_id: RAG corpus ID (required)

    Optional config in vector_store:
    - vertex_project: GCP project ID (uses env VERTEXAI_PROJECT if not set)
    - vertex_location: GCP region (default: us-central1)
    - vertex_credentials: Path to credentials JSON (uses ADC if not set)
    - wait_for_import: Wait for import to complete (default: True)
    - import_timeout: Timeout in seconds (default: 600)

    Chunking is configured via ingest_options["chunking_strategy"]:
    - chunk_size: Maximum size of chunks (default: 1000)
    - chunk_overlap: Overlap between chunks (default: 200)

    Authentication:
    - Uses Application Default Credentials (ADC)
    - Run: gcloud auth application-default login
    """

    def __init__(
        self,
        ingest_options: "RAGIngestOptions",
        router: Optional["Router"] = None,
    ):
        super().__init__(ingest_options=ingest_options, router=router)

        # Get corpus ID (required for Vertex AI)
        self.corpus_id = self.vector_store_config.get("vector_store_id")
        if not self.corpus_id:
            raise ValueError(
                "vector_store_id (corpus ID) is required for Vertex AI RAG ingestion. "
                "Please provide an existing RAG corpus ID."
            )

        # GCP config
        self.vertex_project = (
            self.vector_store_config.get("vertex_project")
            or get_secret_str("VERTEXAI_PROJECT")
        )
        self.vertex_location = (
            self.vector_store_config.get("vertex_location")
            or get_secret_str("VERTEXAI_LOCATION")
            or "us-central1"
        )
        self.vertex_credentials = self.vector_store_config.get("vertex_credentials")

        # GCS bucket for file uploads
        self.gcs_bucket = (
            self.vector_store_config.get("gcs_bucket")
            or os.environ.get("GCS_BUCKET_NAME")
        )
        if not self.gcs_bucket:
            raise ValueError(
                "gcs_bucket is required for Vertex AI RAG ingestion. "
                "Set via vector_store config or GCS_BUCKET_NAME env var."
            )

        # Import settings
        self.wait_for_import = self.vector_store_config.get("wait_for_import", True)
        self.import_timeout = _get_int(
            self.vector_store_config.get("import_timeout"), 600
        )

        # Validate required config
        if not self.vertex_project:
            raise ValueError(
                "vertex_project is required for Vertex AI RAG ingestion. "
                "Set via vector_store config or VERTEXAI_PROJECT env var."
            )

    def _get_corpus_name(self) -> str:
        """Get full corpus resource name."""
        return f"projects/{self.vertex_project}/locations/{self.vertex_location}/ragCorpora/{self.corpus_id}"

    async def _upload_file_to_gcs(
        self,
        file_content: bytes,
        filename: str,
        content_type: str,
    ) -> str:
        """
        Upload file to GCS using litellm.files.acreate_file.

        Returns:
            GCS URI of the uploaded file (gs://bucket/path/file)
        """
        import litellm

        # Set GCS_BUCKET_NAME env var for litellm.files.create_file
        # The handler uses this to determine where to upload
        original_bucket = os.environ.get("GCS_BUCKET_NAME")
        if self.gcs_bucket:
            os.environ["GCS_BUCKET_NAME"] = self.gcs_bucket

        try:
            # Create file tuple for litellm.files.acreate_file
            file_tuple = (filename, file_content, content_type)

            verbose_logger.debug(
                f"Uploading file to GCS via litellm.files.acreate_file: {filename} "
                f"(bucket: {self.gcs_bucket})"
            )

            # Upload to GCS using LiteLLM's file upload
            response = await litellm.acreate_file(
                file=file_tuple,
                purpose="assistants",  # Purpose for file storage
                custom_llm_provider="vertex_ai",
                vertex_project=self.vertex_project,
                vertex_location=self.vertex_location,
                vertex_credentials=self.vertex_credentials,
            )

            # The response.id should be the GCS URI
            gcs_uri = response.id
            verbose_logger.info(f"Uploaded file to GCS: {gcs_uri}")

            return gcs_uri
        finally:
            # Restore original env var
            if original_bucket is not None:
                os.environ["GCS_BUCKET_NAME"] = original_bucket
            elif "GCS_BUCKET_NAME" in os.environ:
                del os.environ["GCS_BUCKET_NAME"]

    async def _import_file_to_corpus_via_sdk(
        self,
        gcs_uri: str,
    ) -> None:
        """
        Import file into RAG corpus using the Vertex AI SDK.

        The REST API endpoint for importRagFiles is not publicly available,
        so we use the Python SDK.
        """
        try:
            from vertexai import init as vertexai_init
            from vertexai import rag  # type: ignore[import-not-found]
        except ImportError:
            raise ImportError(
                "vertexai.rag module not found. Vertex AI RAG requires "
                "google-cloud-aiplatform>=1.60.0. Install with: "
                "pip install 'google-cloud-aiplatform>=1.60.0'"
            )

        # Initialize Vertex AI
        vertexai_init(project=self.vertex_project, location=self.vertex_location)

        # Get chunking config from ingest_options (unified interface)
        transformation_config = self._build_transformation_config()

        corpus_name = self._get_corpus_name()
        verbose_logger.debug(f"Importing {gcs_uri} into corpus {self.corpus_id}")

        if self.wait_for_import:
            # Synchronous import - wait for completion
            response = rag.import_files(
                corpus_name=corpus_name,
                paths=[gcs_uri],
                transformation_config=transformation_config,
                timeout=self.import_timeout,
            )
            verbose_logger.info(
                f"Import complete: {response.imported_rag_files_count} files imported"
            )
        else:
            # Async import - don't wait
            _ = rag.import_files_async(
                corpus_name=corpus_name,
                paths=[gcs_uri],
                transformation_config=transformation_config,
            )
            verbose_logger.info("Import started asynchronously")

    def _build_transformation_config(self) -> Any:
        """
        Build Vertex AI TransformationConfig from unified chunking_strategy.

        Uses chunking_strategy from ingest_options (not vector_store).
        """
        try:
            from vertexai import rag  # type: ignore[import-not-found]
        except ImportError:
            raise ImportError(
                "vertexai.rag module not found. Vertex AI RAG requires "
                "google-cloud-aiplatform>=1.60.0. Install with: "
                "pip install 'google-cloud-aiplatform>=1.60.0'"
            )

        # Get chunking config from ingest_options using transformation class
        from typing import cast

        from litellm.types.rag import RAGChunkingStrategy

        transformation = VertexAIRAGTransformation()
        chunking_config = transformation.transform_chunking_strategy_to_vertex_format(
            cast(Optional[RAGChunkingStrategy], self.chunking_strategy)
        )

        chunk_size = chunking_config["chunking_config"]["chunk_size"]
        chunk_overlap = chunking_config["chunking_config"]["chunk_overlap"]

        return rag.TransformationConfig(
            chunking_config=rag.ChunkingConfig(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            ),
        )

    async def embed(
        self,
        chunks: List[str],
    ) -> Optional[List[List[float]]]:
        """
        Vertex AI handles embedding internally - skip this step.

        Returns:
            None (Vertex AI embeds when files are imported)
        """
        return None

    async def store(
        self,
        file_content: Optional[bytes],
        filename: Optional[str],
        content_type: Optional[str],
        chunks: List[str],
        embeddings: Optional[List[List[float]]],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Store content in Vertex AI RAG corpus.

        Vertex AI workflow:
        1. Upload file to GCS via litellm.files.acreate_file
        2. Import file into RAG corpus via SDK
        3. (Optional) Wait for import to complete

        Args:
            file_content: Raw file bytes
            filename: Name of the file
            content_type: MIME type
            chunks: Ignored - Vertex AI handles chunking
            embeddings: Ignored - Vertex AI handles embedding

        Returns:
            Tuple of (corpus_id, gcs_uri)
        """
        if not file_content or not filename:
            verbose_logger.warning(
                "No file content or filename provided for Vertex AI ingestion"
            )
            return _get_str_or_none(self.corpus_id), None

        # Step 1: Upload file to GCS
        gcs_uri = await self._upload_file_to_gcs(
            file_content=file_content,
            filename=filename,
            content_type=content_type or "application/octet-stream",
        )

        # Step 2: Import file into RAG corpus
        try:
            await self._import_file_to_corpus_via_sdk(gcs_uri=gcs_uri)
        except Exception as e:
            verbose_logger.error(f"Failed to import file into RAG corpus: {e}")
            raise RuntimeError(f"Failed to import file into RAG corpus: {e}") from e

        return str(self.corpus_id), gcs_uri
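
# Usage sketch: assuming RAGIngestOptions accepts plain dicts (as the test file
# below suggests), ingesting into an existing corpus looks roughly like:
#
#     ingestion = VertexAIRAGIngestion(
#         ingest_options={
#             "chunking_strategy": {"chunk_size": 512, "chunk_overlap": 100},
#             "vector_store": {
#                 "custom_llm_provider": "vertex_ai",
#                 "vector_store_id": "your-corpus-id",
#                 "gcs_bucket": "your-gcs-bucket",
#             },
#         },
#     )
#     corpus_id, gcs_uri = await ingestion.store(
#         file_content=b"hello world",
#         filename="hello.txt",
#         content_type="text/plain",
#         chunks=[],         # ignored - Vertex AI chunks server-side
#         embeddings=None,   # ignored - Vertex AI embeds on import
#     )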

litellm/llms/vertex_ai/rag_engine/transformation.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""
Transformation utilities for Vertex AI RAG Engine.

Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
"""

from typing import Any, Dict, Optional, Tuple

from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.rag import RAGChunkingStrategy


class VertexAIRAGTransformation(VertexBase):
    """
    Transformation class for Vertex AI RAG Engine API.

    Handles:
    - Converting unified chunking_strategy to Vertex AI format
    - Building import request payloads
    - Transforming responses
    """

    def __init__(self):
        super().__init__()

    def get_import_rag_files_url(
        self,
        vertex_project: str,
        vertex_location: str,
        corpus_id: str,
    ) -> str:
        """
        Get the URL for importing RAG files.

        Note: The REST endpoint for importRagFiles may not be publicly available.
        Vertex AI RAG Engine primarily uses the gRPC-based SDK.
        """
        base_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1"
        return f"{base_url}/projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{corpus_id}:importRagFiles"

    def get_retrieve_contexts_url(
        self,
        vertex_project: str,
        vertex_location: str,
    ) -> str:
        """Get the URL for retrieving contexts (search)."""
        base_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1"
        return f"{base_url}/projects/{vertex_project}/locations/{vertex_location}:retrieveContexts"

    def transform_chunking_strategy_to_vertex_format(
        self,
        chunking_strategy: Optional[RAGChunkingStrategy],
    ) -> Dict[str, Any]:
        """
        Transform LiteLLM's unified chunking_strategy to Vertex AI RAG format.

        LiteLLM format (RAGChunkingStrategy):
            {
                "chunk_size": 1000,
                "chunk_overlap": 200,
                "separators": ["\n\n", "\n", " ", ""]
            }

        Vertex AI RAG format (TransformationConfig):
            {
                "chunking_config": {
                    "chunk_size": 1000,
                    "chunk_overlap": 200
                }
            }

        Note: Vertex AI doesn't support custom separators in the same way,
        so we only transform chunk_size and chunk_overlap.
        """
        if not chunking_strategy:
            return {
                "chunking_config": {
                    "chunk_size": DEFAULT_CHUNK_SIZE,
                    "chunk_overlap": DEFAULT_CHUNK_OVERLAP,
                }
            }

        chunk_size = chunking_strategy.get("chunk_size", DEFAULT_CHUNK_SIZE)
        chunk_overlap = chunking_strategy.get("chunk_overlap", DEFAULT_CHUNK_OVERLAP)

        # Log if separators are provided (not supported by Vertex AI)
        if chunking_strategy.get("separators"):
            verbose_logger.warning(
                "Vertex AI RAG Engine does not support custom separators. "
                "The 'separators' parameter will be ignored."
            )

        return {
            "chunking_config": {
                "chunk_size": chunk_size,
                "chunk_overlap": chunk_overlap,
            }
        }
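
    # Example: {"chunk_size": 512, "chunk_overlap": 100, "separators": ["\n"]}
    # becomes {"chunking_config": {"chunk_size": 512, "chunk_overlap": 100}},
    # with a warning that separators are ignored; None yields the
    # DEFAULT_CHUNK_SIZE / DEFAULT_CHUNK_OVERLAP config.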

    def build_import_rag_files_request(
        self,
        gcs_uri: str,
        chunking_strategy: Optional[RAGChunkingStrategy] = None,
    ) -> Dict[str, Any]:
        """
        Build the request payload for importing RAG files.

        Args:
            gcs_uri: GCS URI of the file to import (e.g., gs://bucket/path/file.txt)
            chunking_strategy: LiteLLM unified chunking config

        Returns:
            Request payload dict for importRagFiles API
        """
        transformation_config = self.transform_chunking_strategy_to_vertex_format(
            chunking_strategy
        )

        return {
            "import_rag_files_config": {
                "gcs_source": {
                    "uris": [gcs_uri]
                },
                "rag_file_transformation_config": transformation_config,
            }
        }

    def get_auth_headers(
        self,
        vertex_credentials: Optional[str] = None,
        vertex_project: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Get authentication headers for Vertex AI API calls.

        Uses the base class method to get credentials.
        """
        credentials = self.get_vertex_ai_credentials(
            {"vertex_credentials": vertex_credentials}
        )
        project = vertex_project or self.get_vertex_ai_project({})

        access_token, _ = self._ensure_access_token(
            credentials=credentials,
            project_id=project,
            custom_llm_provider="vertex_ai",
        )

        return {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

@@ -12,7 +12,7 @@ __all__ = ["ingest", "aingest"]
 import asyncio
 import contextvars
 from functools import partial
-from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
 
 import httpx
 
@@ -84,7 +84,7 @@ async def _execute_ingest_pipeline(
     provider = vector_store_config.get("custom_llm_provider", "openai")
 
     # Get provider-specific ingestion class
-    ingestion_class = get_ingestion_class(provider)
+    ingestion_class = get_rag_ingestion_class(provider)
 
     # Create ingestion instance
     ingestion = ingestion_class(

litellm/rag/utils.py (new file, 65 lines)
@@ -0,0 +1,65 @@
"""
RAG utility functions.

Provides provider configuration utilities similar to ProviderConfigManager.
"""

from typing import TYPE_CHECKING, Optional, Type

if TYPE_CHECKING:
    from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion


def get_rag_ingestion_class(custom_llm_provider: str) -> Type["BaseRAGIngestion"]:
    """
    Get the appropriate RAG ingestion class for a provider.

    Args:
        custom_llm_provider: The LLM provider name (e.g., "openai", "bedrock", "vertex_ai")

    Returns:
        The ingestion class for the provider

    Raises:
        ValueError: If the provider is not supported
    """
    from litellm.llms.vertex_ai.rag_engine.ingestion import VertexAIRAGIngestion
    from litellm.rag.ingestion.bedrock_ingestion import BedrockRAGIngestion
    from litellm.rag.ingestion.openai_ingestion import OpenAIRAGIngestion

    provider_map = {
        "openai": OpenAIRAGIngestion,
        "bedrock": BedrockRAGIngestion,
        "vertex_ai": VertexAIRAGIngestion,
    }

    ingestion_class = provider_map.get(custom_llm_provider)
    if ingestion_class is None:
        raise ValueError(
            f"RAG ingestion not supported for provider: {custom_llm_provider}. "
            f"Supported providers: {list(provider_map.keys())}"
        )

    return ingestion_class
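
# Example:
#   get_rag_ingestion_class("vertex_ai")  # -> VertexAIRAGIngestion
#   get_rag_ingestion_class("cohere")     # raises ValueError (unsupported)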

def get_rag_transformation_class(custom_llm_provider: str):
    """
    Get the appropriate RAG transformation class for a provider.

    Args:
        custom_llm_provider: The LLM provider name

    Returns:
        The transformation class for the provider, or None if not needed
    """
    if custom_llm_provider == "vertex_ai":
        from litellm.llms.vertex_ai.rag_engine.transformation import (
            VertexAIRAGTransformation,
        )

        return VertexAIRAGTransformation

    # OpenAI and Bedrock don't need special transformations
    return None
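
# Example:
#   get_rag_transformation_class("vertex_ai")  # -> VertexAIRAGTransformation
#   get_rag_transformation_class("openai")     # -> None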

@@ -86,8 +86,37 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
     aws_external_id: Optional[str]
 
 
+class VertexAIVectorStoreOptions(TypedDict, total=False):
+    """
+    Vertex AI RAG Engine configuration.
+
+    Example (use existing corpus):
+        {"custom_llm_provider": "vertex_ai", "vector_store_id": "CORPUS_ID", "gcs_bucket": "my-bucket"}
+
+    Requires:
+    - gcloud auth application-default login (for ADC authentication)
+    - Files are uploaded to GCS via litellm.files.create_file, then imported into RAG corpus
+    - GCS bucket must be provided via gcs_bucket or GCS_BUCKET_NAME env var
+    """
+
+    custom_llm_provider: Literal["vertex_ai"]
+    vector_store_id: str  # RAG corpus ID (required for Vertex AI)
+
+    # GCP config
+    vertex_project: Optional[str]  # GCP project ID (uses env VERTEXAI_PROJECT if not set)
+    vertex_location: Optional[str]  # GCP region (default: us-central1)
+    vertex_credentials: Optional[str]  # Path to credentials JSON (uses ADC if not set)
+    gcs_bucket: Optional[str]  # GCS bucket for file uploads (uses env GCS_BUCKET_NAME if not set)
+
+    # Import settings
+    wait_for_import: Optional[bool]  # Wait for import to complete (default: True)
+    import_timeout: Optional[int]  # Timeout in seconds (default: 600)
+
+
 # Union type for vector store options
-RAGIngestVectorStoreOptions = Union[OpenAIVectorStoreOptions, BedrockVectorStoreOptions]
+RAGIngestVectorStoreOptions = Union[
+    OpenAIVectorStoreOptions, BedrockVectorStoreOptions, VertexAIVectorStoreOptions
+]
 
 
 class RAGIngestOptions(TypedDict, total=False):

tests/vector_store_tests/rag/test_rag_vertex_ai.py (new file, 130 lines)
@@ -0,0 +1,130 @@
"""
Vertex AI RAG Engine ingestion tests.

Requires:
- gcloud auth application-default login (for ADC authentication)

Environment variables:
- VERTEX_PROJECT: GCP project ID (required)
- VERTEX_LOCATION: GCP region (optional, defaults to europe-west1)
- VERTEX_CORPUS_ID: Existing RAG corpus ID (required for Vertex AI)
- GCS_BUCKET_NAME: GCS bucket for file uploads (required)
"""

import os
import sys
from typing import Any, Dict, Optional

import pytest

sys.path.insert(0, os.path.abspath("../../.."))

import litellm
from litellm.types.rag import RAGIngestOptions
from tests.vector_store_tests.rag.base_rag_tests import BaseRAGTest


class TestRAGVertexAI(BaseRAGTest):
    """Test RAG Ingest with Vertex AI RAG Engine."""

    @pytest.fixture(autouse=True)
    def check_env_vars(self):
        """Check required environment variables before each test."""
        vertex_project = os.environ.get("VERTEX_PROJECT")
        corpus_id = os.environ.get("VERTEX_CORPUS_ID")
        gcs_bucket = os.environ.get("GCS_BUCKET_NAME")

        if not vertex_project:
            pytest.skip("Skipping Vertex AI test: VERTEX_PROJECT required")

        if not corpus_id:
            pytest.skip("Skipping Vertex AI test: VERTEX_CORPUS_ID required")

        if not gcs_bucket:
            pytest.skip("Skipping Vertex AI test: GCS_BUCKET_NAME required")

        # Check if vertexai is installed
        try:
            from vertexai import rag
        except ImportError:
            pytest.skip("Skipping Vertex AI test: google-cloud-aiplatform>=1.60.0 required")

    def get_base_ingest_options(self) -> RAGIngestOptions:
        """
        Return Vertex AI-specific ingest options.

        Chunking is configured via chunking_strategy (unified interface),
        not inside vector_store.
        """
        corpus_id = os.environ.get("VERTEX_CORPUS_ID")
        vertex_project = os.environ.get("VERTEX_PROJECT")
        vertex_location = os.environ.get("VERTEX_LOCATION", "europe-west1")
        gcs_bucket = os.environ.get("GCS_BUCKET_NAME")

        return {
            "chunking_strategy": {
                "chunk_size": 512,
                "chunk_overlap": 100,
            },
            "vector_store": {
                "custom_llm_provider": "vertex_ai",
                "vertex_project": vertex_project,
                "vertex_location": vertex_location,
                "vector_store_id": corpus_id,
                "gcs_bucket": gcs_bucket,
                "wait_for_import": True,
            },
        }

    async def query_vector_store(
        self,
        vector_store_id: str,
        query: str,
    ) -> Optional[Dict[str, Any]]:
        """Query Vertex AI RAG corpus."""
        try:
            from vertexai import init as vertexai_init
            from vertexai import rag
        except ImportError:
            pytest.skip("vertexai required for Vertex AI tests")

        vertex_project = os.environ.get("VERTEX_PROJECT")
        vertex_location = os.environ.get("VERTEX_LOCATION", "europe-west1")

        # Initialize Vertex AI
        vertexai_init(project=vertex_project, location=vertex_location)

        # Build corpus name
        corpus_name = f"projects/{vertex_project}/locations/{vertex_location}/ragCorpora/{vector_store_id}"

        # Query the corpus
        response = rag.retrieval_query(
            rag_resources=[
                rag.RagResource(rag_corpus=corpus_name)
            ],
            text=query,
            rag_retrieval_config=rag.RagRetrievalConfig(
                top_k=5,
            ),
        )

        if hasattr(response, 'contexts') and response.contexts.contexts:
            # Convert to dict format
            results = []
            for ctx in response.contexts.contexts:
                results.append({
                    "text": ctx.text,
                    "score": ctx.score,
                    "source_uri": ctx.source_uri,
                })

            # Check if query terms appear in results
            for result in results:
                if query.lower() in result["text"].lower():
                    return {"results": results}

            # Return results even if exact match not found
            return {"results": results}

        return None