[Feat] RAG API - QA - allow internal user keys to access api, allow using litellm credentials with API, raise clear exception when RAG API fails (#17169)

* allow using a cred with RAG API

* add /rag/ingest to llm api routes

* add rag endpoints under llm api routes

* raise clear exception when RAG API fails

* use async methods for bedrock ingest

* fix ingestion

* fix _create_opensearch_collection

* fix qa check and linting

Ishaan Jaff
2025-11-26 17:07:30 -08:00
committed by GitHub
parent 379655e16b
commit 831694897e
11 changed files with 202 additions and 45 deletions
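
Taken together, the changes below mean a LiteLLM virtual key (internal user key) can now call `/rag/ingest` on the proxy, and provider credentials can be resolved server-side via `litellm_credential_name`. A minimal client-side sketch is below; the proxy URL and key are placeholders, and the request body is assumed to mirror the SDK's `ingest_options` / `file_url` parameters shown in the `rag/main.py` docstrings further down, not a confirmed wire schema.

```python
# Hypothetical client call against a LiteLLM proxy exposing the new /rag/ingest route.
# Field names mirror litellm.aingest()'s parameters (see the docstrings in this diff);
# they are assumptions about the request body, not a documented schema.
import httpx

PROXY_BASE = "http://localhost:4000"     # placeholder proxy URL
VIRTUAL_KEY = "sk-internal-user-key"     # internal user key, now allowed on this route

resp = httpx.post(
    f"{PROXY_BASE}/rag/ingest",
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
    json={
        "ingest_options": {
            "vector_store": {
                "custom_llm_provider": "openai",
                "litellm_credential_name": "my-openai-creds",  # resolved server-side
            }
        },
        "file_url": "https://example.com/doc.pdf",
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())  # expected: RAGIngestResponse with status / vector_store_id / file_id
```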

document.txt Normal file
View File

@@ -0,0 +1,19 @@
LiteLLM provides a unified interface for calling 100+ different LLM providers.
Key capabilities:
- Translate requests to provider-specific formats
- Consistent OpenAI-compatible responses
- Retry and fallback logic across deployments
- Proxy server with authentication and rate limiting
- Support for streaming, function calling, and embeddings
Popular providers supported:
- OpenAI (GPT-4, GPT-3.5)
- Anthropic (Claude)
- AWS Bedrock
- Azure OpenAI
- Google Vertex AI
- Cohere
- And 95+ more
This allows developers to easily switch between providers without code changes.

View File

@@ -14,21 +14,16 @@ Key differences from OpenAI:
from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from litellm import get_secret_str
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
from litellm.types.llms.custom_http import httpxSpecialProvider
if TYPE_CHECKING:
from litellm import Router
from litellm.types.rag import RAGChunkingStrategy, RAGIngestOptions
from litellm.types.rag import RAGIngestOptions
def _get_str_or_none(value: Any) -> Optional[str]:

View File

@@ -4,7 +4,7 @@ Transformation utilities for Vertex AI RAG Engine.
Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
"""
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional
from litellm._logging import verbose_logger
from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE

View File

@@ -20020,6 +20020,25 @@
"supports_vision": true,
"tool_use_system_prompt_tokens": 159
},
"openrouter/anthropic/claude-opus-4.5": {
"cache_creation_input_token_cost": 6.25e-06,
"cache_read_input_token_cost": 5e-07,
"input_cost_per_token": 5e-06,
"litellm_provider": "openrouter",
"max_input_tokens": 200000,
"max_output_tokens": 32000,
"max_tokens": 32000,
"mode": "chat",
"output_cost_per_token": 2.5e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
"supports_prompt_caching": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"tool_use_system_prompt_tokens": 159
},
"openrouter/anthropic/claude-sonnet-4.5": {
"input_cost_per_image": 0.0048,
"cache_creation_input_token_cost": 3.75e-06,

View File

@@ -379,6 +379,11 @@ class LiteLLMRoutes(enum.Enum):
#########################################################
passthrough_routes_wildcard = [f"{route}/*" for route in mapped_pass_through_routes]
litellm_native_routes = [
"/rag/ingest",
"/v1/rag/ingest",
]
anthropic_routes = [
"/v1/messages",
"/v1/messages/count_tokens",
@@ -416,6 +421,7 @@ class LiteLLMRoutes(enum.Enum):
+ passthrough_routes_wildcard
+ apply_guardrail_routes
+ mcp_routes
+ litellm_native_routes
)
info_routes = [
"/key/info",

View File

@@ -61,6 +61,26 @@ class BaseRAGIngestion(ABC):
)
self.ingest_name = ingest_options.get("name")
# Load credentials from litellm_credential_name if provided in vector_store config
self._load_credentials_from_config()
def _load_credentials_from_config(self) -> None:
"""
Load credentials from litellm_credential_name if provided in vector_store config.
This allows users to specify a credential name in the vector_store config
which will be resolved from litellm.credential_list.
"""
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
credential_name = self.vector_store_config.get("litellm_credential_name")
if credential_name and litellm.credential_list:
credential_values = CredentialAccessor.get_credential_values(credential_name)
# Merge credentials into vector_store_config (don't overwrite existing values)
for key, value in credential_values.items():
if key not in self.vector_store_config:
self.vector_store_config[key] = value
@property
def custom_llm_provider(self) -> str:
"""Get the vector store provider."""
@@ -317,5 +337,6 @@ class BaseRAGIngestion(ABC):
status="failed",
vector_store_id="",
file_id=None,
error=str(e),
)
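
To make the merge semantics concrete: a credential registered under `litellm.credential_list` fills in missing keys of `vector_store_config` but never overwrites values the caller set explicitly. A standalone sketch of that behavior is below; credential storage is simplified to a plain dict, whereas the real code resolves it through `CredentialAccessor.get_credential_values`.

```python
# Sketch of the credential merge performed by _load_credentials_from_config.
# Credentials are modeled as a plain {name: values} dict here; the real code
# resolves them via CredentialAccessor.get_credential_values(credential_name).
from typing import Any, Dict

registered_credentials: Dict[str, Dict[str, Any]] = {
    "my-openai-creds": {"api_key": "sk-...", "api_base": "https://api.openai.com/v1"},
}

vector_store_config: Dict[str, Any] = {
    "custom_llm_provider": "openai",
    "litellm_credential_name": "my-openai-creds",
    "api_base": "https://my-gateway.example.com/v1",  # explicit value wins
}

credential_name = vector_store_config.get("litellm_credential_name")
if credential_name and credential_name in registered_credentials:
    for key, value in registered_credentials[credential_name].items():
        if key not in vector_store_config:  # don't overwrite existing values
            vector_store_config[key] = value

assert vector_store_config["api_key"] == "sk-..."  # filled from the credential
assert vector_store_config["api_base"] == "https://my-gateway.example.com/v1"  # preserved
```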

View File

@@ -11,8 +11,8 @@ Supports two modes:
from __future__ import annotations
import asyncio
import json
import time
import uuid
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -37,6 +37,29 @@ def _get_int(value: Any, default: int) -> int:
return int(value)
def _normalize_principal_arn(caller_arn: str, account_id: str) -> str:
"""
Normalize a caller ARN to the format required by OpenSearch data access policies.
OpenSearch Serverless data access policies require:
- IAM users: arn:aws:iam::account-id:user/user-name
- IAM roles: arn:aws:iam::account-id:role/role-name
But get_caller_identity() returns for assumed roles:
- arn:aws:sts::account-id:assumed-role/role-name/session-name
This function converts assumed-role ARNs to the proper IAM role ARN format.
"""
if ":assumed-role/" in caller_arn:
# Extract role name from assumed-role ARN
# Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE-NAME/SESSION-NAME
parts = caller_arn.split("/")
if len(parts) >= 2:
role_name = parts[1]
return f"arn:aws:iam::{account_id}:role/{role_name}"
return caller_arn
class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
"""
Bedrock Knowledge Base RAG ingestion.
@@ -99,7 +122,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
# Track resources we create (for cleanup if needed)
self._created_resources: Dict[str, Any] = {}
def _ensure_config_initialized(self):
async def _ensure_config_initialized(self):
"""Lazily initialize KB config - either detect from existing or create new."""
if self._config_initialized:
return
@@ -109,7 +132,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
self._auto_detect_config()
else:
# No KB provided - create everything from scratch
self._create_knowledge_base_infrastructure()
await self._create_knowledge_base_infrastructure()
self._config_initialized = True
@@ -170,7 +193,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
)
self.s3_bucket = self._s3_bucket
def _create_knowledge_base_infrastructure(self):
async def _create_knowledge_base_infrastructure(self):
"""Create all AWS resources needed for a new Knowledge Base."""
verbose_logger.info("Creating new Bedrock Knowledge Base infrastructure...")
@@ -178,26 +201,28 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
unique_id = uuid.uuid4().hex[:8]
kb_name = self.ingest_name or f"litellm-kb-{unique_id}"
# Get AWS account ID
# Get AWS account ID and caller ARN (for data access policy)
sts = self._get_boto3_client("sts")
account_id = sts.get_caller_identity()["Account"]
caller_identity = sts.get_caller_identity()
account_id = caller_identity["Account"]
caller_arn = caller_identity["Arn"]
# Step 1: Create S3 bucket (if not provided)
self.s3_bucket = self._s3_bucket or self._create_s3_bucket(unique_id)
# Step 2: Create OpenSearch Serverless collection
collection_name, collection_arn = self._create_opensearch_collection(
unique_id, account_id
collection_name, collection_arn = await self._create_opensearch_collection(
unique_id, account_id, caller_arn
)
# Step 3: Create OpenSearch index
self._create_opensearch_index(collection_name)
await self._create_opensearch_index(collection_name)
# Step 4: Create IAM role for Bedrock
role_arn = self._create_bedrock_role(unique_id, account_id, collection_arn)
role_arn = await self._create_bedrock_role(unique_id, account_id, collection_arn)
# Step 5: Create Knowledge Base
self.knowledge_base_id = self._create_knowledge_base(
self.knowledge_base_id = await self._create_knowledge_base(
kb_name, role_arn, collection_arn
)
@@ -228,8 +253,8 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
verbose_logger.info(f"Created S3 bucket: {bucket_name}")
return bucket_name
def _create_opensearch_collection(
self, unique_id: str, account_id: str
async def _create_opensearch_collection(
self, unique_id: str, account_id: str, caller_arn: str
) -> Tuple[str, str]:
"""Create OpenSearch Serverless collection for vector storage."""
oss = self._get_boto3_client("opensearchserverless")
@@ -258,7 +283,16 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
}]),
)
# Create data access policy
# Create data access policy - include both root and actual caller ARN
# This ensures the credentials being used have access to the collection
# Normalize the caller ARN (convert assumed-role ARN to IAM role ARN if needed)
normalized_caller_arn = _normalize_principal_arn(caller_arn, account_id)
verbose_logger.debug(f"Caller ARN: {caller_arn}, Normalized: {normalized_caller_arn}")
principals = [f"arn:aws:iam::{account_id}:root", normalized_caller_arn]
# Deduplicate in case caller is root
principals = list(set(principals))
oss.create_access_policy(
name=f"{collection_name}-access",
type="data",
@@ -267,7 +301,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
{"ResourceType": "index", "Resource": [f"index/{collection_name}/*"], "Permission": ["aoss:*"]},
{"ResourceType": "collection", "Resource": [f"collection/{collection_name}"], "Permission": ["aoss:*"]},
],
"Principal": [f"arn:aws:iam::{account_id}:root"],
"Principal": principals,
}]),
)
@@ -279,24 +313,29 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
collection_id = response["createCollectionDetail"]["id"]
self._created_resources["opensearch_collection"] = collection_name
# Wait for collection to be active
# Wait for collection to be active (use asyncio.sleep to avoid blocking)
verbose_logger.debug("Waiting for OpenSearch collection to be active...")
for _ in range(60): # 5 min timeout
status_response = oss.batch_get_collection(ids=[collection_id])
status = status_response["collectionDetails"][0]["status"]
if status == "ACTIVE":
break
time.sleep(5)
await asyncio.sleep(5)
else:
raise TimeoutError("OpenSearch collection did not become active in time")
collection_arn = status_response["collectionDetails"][0]["arn"]
verbose_logger.info(f"Created OpenSearch collection: {collection_name}")
# Wait for data access policy to propagate before returning
# AWS recommends waiting 60+ seconds for policy propagation
verbose_logger.debug("Waiting for data access policy to propagate (60s)...")
await asyncio.sleep(60)
return collection_name, collection_arn
def _create_opensearch_index(self, collection_name: str):
"""Create vector index in OpenSearch collection."""
async def _create_opensearch_index(self, collection_name: str):
"""Create vector index in OpenSearch collection with retry logic."""
from opensearchpy import OpenSearch, RequestsHttpConnection
from requests_aws4auth import AWS4Auth
@@ -348,10 +387,36 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
},
}
client.indices.create(index=index_name, body=index_body)
verbose_logger.info(f"Created OpenSearch index: {index_name}")
# Retry logic for index creation - data access policy may take time to propagate
max_retries = 8
retry_delay = 20 # seconds
last_error = None
for attempt in range(max_retries):
try:
client.indices.create(index=index_name, body=index_body)
verbose_logger.info(f"Created OpenSearch index: {index_name}")
return
except Exception as e:
last_error = e
error_str = str(e)
if "authorization_exception" in error_str.lower() or "security_exception" in error_str.lower():
verbose_logger.warning(
f"OpenSearch index creation attempt {attempt + 1}/{max_retries} failed due to authorization. "
f"Waiting {retry_delay}s for policy propagation..."
)
await asyncio.sleep(retry_delay)
else:
# Non-auth error, raise immediately
raise
# All retries exhausted
raise RuntimeError(
f"Failed to create OpenSearch index after {max_retries} attempts. "
f"Data access policy may not have propagated. Last error: {last_error}"
)
def _create_bedrock_role(
async def _create_bedrock_role(
self, unique_id: str, account_id: str, collection_arn: str
) -> str:
"""Create IAM role for Bedrock KB."""
@@ -408,13 +473,13 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
PolicyDocument=json.dumps(permissions_policy),
)
# Wait for role to propagate
time.sleep(10)
# Wait for role to propagate (use asyncio.sleep to avoid blocking)
await asyncio.sleep(10)
verbose_logger.info(f"Created IAM role: {role_arn}")
return role_arn
def _create_knowledge_base(
async def _create_knowledge_base(
self, kb_name: str, role_arn: str, collection_arn: str
) -> str:
"""Create Bedrock Knowledge Base."""
@@ -447,14 +512,14 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
kb_id = response["knowledgeBase"]["knowledgeBaseId"]
self._created_resources["knowledge_base"] = kb_id
# Wait for KB to be active
# Wait for KB to be active (use asyncio.sleep to avoid blocking)
verbose_logger.debug("Waiting for Knowledge Base to be active...")
for _ in range(30):
kb_status = bedrock_agent.get_knowledge_base(knowledgeBaseId=kb_id)
status = kb_status["knowledgeBase"]["status"]
if status == "ACTIVE":
break
time.sleep(2)
await asyncio.sleep(2)
else:
raise TimeoutError("Knowledge Base did not become active in time")
@@ -555,7 +620,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
Tuple of (knowledge_base_id, file_key)
"""
# Auto-detect data source and S3 bucket if needed
self._ensure_config_initialized()
await self._ensure_config_initialized()
if not file_content or not filename:
verbose_logger.warning("No file content or filename provided for Bedrock ingestion")
@@ -587,10 +652,11 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
job_id = ingestion_response["ingestionJob"]["ingestionJobId"]
verbose_logger.info(f"Started ingestion job: {job_id}")
# Step 3: Wait for ingestion (optional)
# Step 3: Wait for ingestion (optional) - use asyncio.sleep to avoid blocking
if self.wait_for_ingestion:
start_time = time.time()
while time.time() - start_time < self.ingestion_timeout:
import time as time_module
start_time = time_module.time()
while time_module.time() - start_time < self.ingestion_timeout:
job_status = bedrock_agent.get_ingestion_job(
knowledgeBaseId=self.knowledge_base_id,
dataSourceId=self.data_source_id,
@@ -610,7 +676,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
verbose_logger.error(f"Ingestion failed: {failure_reasons}")
break
elif status in ("STARTING", "IN_PROGRESS"):
time.sleep(2)
await asyncio.sleep(2)
else:
verbose_logger.warning(f"Unknown ingestion status: {status}")
break
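
The ARN normalization above is easy to get wrong, so here is the same logic lifted out of the diff as a standalone snippet with example inputs; the account ID and role/session names are made up.

```python
# Self-contained copy of the _normalize_principal_arn logic from this diff,
# with example ARNs (account ID and role/session names are invented).
def _normalize_principal_arn(caller_arn: str, account_id: str) -> str:
    if ":assumed-role/" in caller_arn:
        # Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE-NAME/SESSION-NAME
        parts = caller_arn.split("/")
        if len(parts) >= 2:
            role_name = parts[1]
            return f"arn:aws:iam::{account_id}:role/{role_name}"
    return caller_arn

# Assumed-role ARN (what sts.get_caller_identity() returns) -> IAM role ARN
print(_normalize_principal_arn(
    "arn:aws:sts::123456789012:assumed-role/litellm-ingest/session-1", "123456789012"
))  # arn:aws:iam::123456789012:role/litellm-ingest

# IAM user ARNs pass through unchanged
print(_normalize_principal_arn(
    "arn:aws:iam::123456789012:user/alice", "123456789012"
))  # arn:aws:iam::123456789012:user/alice
```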

View File

@@ -78,6 +78,10 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
vector_store_id = self.vector_store_config.get("vector_store_id")
ttl_days = self.vector_store_config.get("ttl_days")
# Get credentials from vector_store_config (loaded from litellm_credential_name if provided)
api_key = self.vector_store_config.get("api_key")
api_base = self.vector_store_config.get("api_base")
# Create vector store if not provided
if not vector_store_id:
expires_after = {"anchor": "last_active_at", "days": ttl_days} if ttl_days else None
@@ -85,6 +89,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
name=self.ingest_name or "litellm-rag-ingest",
custom_llm_provider="openai",
expires_after=expires_after,
api_key=api_key,
api_base=api_base,
)
vector_store_id = create_response.get("id")
@@ -96,6 +102,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
file=(filename, file_content, content_type or "application/octet-stream"),
purpose="assistants",
custom_llm_provider="openai",
api_key=api_key,
api_base=api_base,
)
result_file_id = file_response.id
@@ -105,6 +113,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
file_id=result_file_id,
custom_llm_provider="openai",
chunking_strategy=cast(Optional[Dict[str, Any]], self.chunking_strategy),
api_key=api_key,
api_base=api_base,
)
return vector_store_id, result_file_id
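
Because the OpenAI path now threads `api_key` / `api_base` from the vector_store config into the create, file-upload, and attach calls, credentials can also be passed inline instead of via `litellm_credential_name`. A minimal sketch, with option names taken from `OpenAIVectorStoreOptions` in `types/rag.py` below and placeholder values:

```python
# Sketch: passing OpenAI credentials directly instead of via litellm_credential_name.
# Option names come from OpenAIVectorStoreOptions in this diff; the key and base URL
# are placeholders.
import litellm

response = litellm.ingest(
    ingest_options={
        "vector_store": {
            "custom_llm_provider": "openai",
            "api_key": "sk-proj-...",                 # used for all three OpenAI calls
            "api_base": "https://api.openai.com/v1",  # optional override
            "ttl_days": 7,
        }
    },
    file_data=("doc.txt", b"Hello world", "text/plain"),
)
print(response.get("vector_store_id"), response.get("file_id"))
```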

View File

@@ -12,7 +12,7 @@ __all__ = ["ingest", "aingest"]
import asyncio
import contextvars
from functools import partial
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union
import httpx
@@ -84,7 +84,7 @@ async def _execute_ingest_pipeline(
provider = vector_store_config.get("custom_llm_provider", "openai")
# Get provider-specific ingestion class
ingestion_class = get_rag_ingestion_class(provider)
ingestion_class = get_ingestion_class(provider)
# Create ingestion instance
ingestion = ingestion_class(
@@ -127,7 +127,10 @@ async def aingest(
```python
response = await litellm.aingest(
ingest_options={
"vector_store": {"custom_llm_provider": "openai"}
"vector_store": {
"custom_llm_provider": "openai",
"litellm_credential_name": "my-openai-creds", # optional
}
},
file_url="https://example.com/doc.pdf",
)
@@ -193,7 +196,10 @@ def ingest(
```python
response = litellm.ingest(
ingest_options={
"vector_store": {"custom_llm_provider": "openai"}
"vector_store": {
"custom_llm_provider": "openai",
"litellm_credential_name": "my-openai-creds", # optional
}
},
file_data=("doc.txt", b"Hello world", "text/plain"),
)

View File

@@ -4,7 +4,7 @@ RAG utility functions.
Provides provider configuration utilities similar to ProviderConfigManager.
"""
from typing import TYPE_CHECKING, Optional, Type
from typing import TYPE_CHECKING, Type
if TYPE_CHECKING:
from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion

View File

@@ -41,12 +41,20 @@ class OpenAIVectorStoreOptions(TypedDict, total=False):
Example (use existing):
{"custom_llm_provider": "openai", "vector_store_id": "vs_xxx"}
Example (with credentials):
{"custom_llm_provider": "openai", "litellm_credential_name": "my-openai-creds"}
"""
custom_llm_provider: Literal["openai"]
vector_store_id: Optional[str] # Existing VS ID (auto-creates if not provided)
ttl_days: Optional[int] # Time-to-live in days for indexed content
# Credentials (loaded from litellm.credential_list if litellm_credential_name is provided)
litellm_credential_name: Optional[str] # Credential name to load from litellm.credential_list
api_key: Optional[str] # Direct API key (alternative to litellm_credential_name)
api_base: Optional[str] # Direct API base (alternative to litellm_credential_name)
class BedrockVectorStoreOptions(TypedDict, total=False):
"""
@@ -58,6 +66,9 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
Example (use existing KB):
{"custom_llm_provider": "bedrock", "vector_store_id": "KB_ID"}
Example (with credentials):
{"custom_llm_provider": "bedrock", "litellm_credential_name": "my-aws-creds"}
Auto-creation creates: S3 bucket, OpenSearch Serverless collection,
IAM role, Knowledge Base, and Data Source.
"""
@@ -73,6 +84,9 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
wait_for_ingestion: Optional[bool] # Wait for completion (default: False - returns immediately)
ingestion_timeout: Optional[int] # Timeout in seconds if wait_for_ingestion=True (default: 300)
# Credentials (loaded from litellm.credential_list if litellm_credential_name is provided)
litellm_credential_name: Optional[str] # Credential name to load from litellm.credential_list
# AWS auth (uses BaseAWSLLM)
aws_access_key_id: Optional[str]
aws_secret_access_key: Optional[str]
@@ -160,6 +174,7 @@ class RAGIngestResponse(TypedDict, total=False):
status: Literal["completed", "in_progress", "failed"]
vector_store_id: str # The vector store ID (created or existing)
file_id: Optional[str] # The file ID in the vector store
error: Optional[str] # Error message if status is "failed"
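
The new `error` field is what makes the "raise clear exception when RAG API fails" item actionable for callers: a failed ingestion now carries the underlying message instead of an empty response. A small caller-side sketch, assuming only the TypedDict shape above (the exception type is illustrative):

```python
# Sketch of caller-side handling for the new `error` field on RAGIngestResponse.
# The TypedDict mirrors types/rag.py in this diff; RuntimeError is illustrative.
from typing import Literal, Optional, TypedDict

class RAGIngestResponse(TypedDict, total=False):
    status: Literal["completed", "in_progress", "failed"]
    vector_store_id: str
    file_id: Optional[str]
    error: Optional[str]

def ensure_ingested(resp: RAGIngestResponse) -> str:
    """Return the vector store ID, or raise with the propagated error message."""
    if resp.get("status") == "failed":
        raise RuntimeError(f"RAG ingestion failed: {resp.get('error') or 'unknown error'}")
    return resp["vector_store_id"]

ok = ensure_ingested({"status": "completed", "vector_store_id": "vs_123", "file_id": "file_1"})
assert ok == "vs_123"
```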