[Feat] RAG API - QA - allow internal user keys to access api, allow using litellm credentials with API, raise clear exception when RAG API fails (#17169)
* allow using a cred with RAG API
* add /rag/ingest to llm api routes
* add rag endpoints under llm api routes
* raise clear exception when RAG API fails
* use async methods for bedrock ingest
* fix ingestion
* fix _create_opensearch_collection
* fix qa check and linting
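End to end, the credential flow added here looks like the following. A minimal sketch, assuming a credential named "my-openai-creds" has already been registered in litellm.credential_list (names and URLs are illustrative):

```python
import asyncio

import litellm


async def main():
    response = await litellm.aingest(
        ingest_options={
            "vector_store": {
                "custom_llm_provider": "openai",
                # resolved from litellm.credential_list at ingestion time
                "litellm_credential_name": "my-openai-creds",
            }
        },
        file_url="https://example.com/doc.pdf",
    )
    # RAGIngestResponse fields, per the types diff below
    print(response["status"], response["vector_store_id"])


asyncio.run(main())
```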
document.txt (new file, 19 lines)
@@ -0,0 +1,19 @@
+LiteLLM provides a unified interface for calling 100+ different LLM providers.
+
+Key capabilities:
+- Translate requests to provider-specific formats
+- Consistent OpenAI-compatible responses
+- Retry and fallback logic across deployments
+- Proxy server with authentication and rate limiting
+- Support for streaming, function calling, and embeddings
+
+Popular providers supported:
+- OpenAI (GPT-4, GPT-3.5)
+- Anthropic (Claude)
+- AWS Bedrock
+- Azure OpenAI
+- Google Vertex AI
+- Cohere
+- And 95+ more
+
+This allows developers to easily switch between providers without code changes.
@@ -14,21 +14,16 @@ Key differences from OpenAI:
 from __future__ import annotations

 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple

 from litellm import get_secret_str
 from litellm._logging import verbose_logger
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
 )
 from litellm.llms.vertex_ai.rag_engine.transformation import VertexAIRAGTransformation
 from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
 from litellm.types.llms.custom_http import httpxSpecialProvider

 if TYPE_CHECKING:
     from litellm import Router
-    from litellm.types.rag import RAGChunkingStrategy, RAGIngestOptions
+    from litellm.types.rag import RAGIngestOptions


 def _get_str_or_none(value: Any) -> Optional[str]:
@@ -4,7 +4,7 @@ Transformation utilities for Vertex AI RAG Engine.
 Handles transforming LiteLLM's unified formats to Vertex AI RAG Engine API format.
 """

-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional

 from litellm._logging import verbose_logger
 from litellm.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
@@ -20020,6 +20020,25 @@
         "supports_vision": true,
         "tool_use_system_prompt_tokens": 159
     },
+    "openrouter/anthropic/claude-opus-4.5": {
+        "cache_creation_input_token_cost": 6.25e-06,
+        "cache_read_input_token_cost": 5e-07,
+        "input_cost_per_token": 5e-06,
+        "litellm_provider": "openrouter",
+        "max_input_tokens": 200000,
+        "max_output_tokens": 32000,
+        "max_tokens": 32000,
+        "mode": "chat",
+        "output_cost_per_token": 2.5e-05,
+        "supports_assistant_prefill": true,
+        "supports_computer_use": true,
+        "supports_function_calling": true,
+        "supports_prompt_caching": true,
+        "supports_reasoning": true,
+        "supports_tool_choice": true,
+        "supports_vision": true,
+        "tool_use_system_prompt_tokens": 159
+    },
     "openrouter/anthropic/claude-sonnet-4.5": {
         "input_cost_per_image": 0.0048,
         "cache_creation_input_token_cost": 3.75e-06,
@@ -379,6 +379,11 @@ class LiteLLMRoutes(enum.Enum):
     #########################################################
     passthrough_routes_wildcard = [f"{route}/*" for route in mapped_pass_through_routes]

+    litellm_native_routes = [
+        "/rag/ingest",
+        "/v1/rag/ingest",
+    ]
+
     anthropic_routes = [
         "/v1/messages",
         "/v1/messages/count_tokens",

@@ -416,6 +421,7 @@ class LiteLLMRoutes(enum.Enum):
         + passthrough_routes_wildcard
         + apply_guardrail_routes
         + mcp_routes
+        + litellm_native_routes
     )
     info_routes = [
         "/key/info",
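With /rag/ingest and /v1/rag/ingest registered under litellm_native_routes, internal user keys that are scoped to the LLM API routes can reach the endpoint. A hedged sketch of a proxy call — the proxy URL and virtual key are hypothetical, and the JSON body is assumed to mirror litellm.aingest's parameters:

```python
import httpx

# Hypothetical deployment-specific values
PROXY_BASE = "http://localhost:4000"
VIRTUAL_KEY = "sk-my-internal-user-key"

resp = httpx.post(
    f"{PROXY_BASE}/v1/rag/ingest",
    headers={"Authorization": f"Bearer {VIRTUAL_KEY}"},
    json={
        # assumed request shape: same fields as litellm.aingest
        "ingest_options": {
            "vector_store": {
                "custom_llm_provider": "openai",
                "litellm_credential_name": "my-openai-creds",
            }
        },
        "file_url": "https://example.com/doc.pdf",
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json())
```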
@@ -61,6 +61,26 @@ class BaseRAGIngestion(ABC):
         )
         self.ingest_name = ingest_options.get("name")

+        # Load credentials from litellm_credential_name if provided in vector_store config
+        self._load_credentials_from_config()
+
+    def _load_credentials_from_config(self) -> None:
+        """
+        Load credentials from litellm_credential_name if provided in vector_store config.
+
+        This allows users to specify a credential name in the vector_store config
+        which will be resolved from litellm.credential_list.
+        """
+        from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
+
+        credential_name = self.vector_store_config.get("litellm_credential_name")
+        if credential_name and litellm.credential_list:
+            credential_values = CredentialAccessor.get_credential_values(credential_name)
+            # Merge credentials into vector_store_config (don't overwrite existing values)
+            for key, value in credential_values.items():
+                if key not in self.vector_store_config:
+                    self.vector_store_config[key] = value
+
     @property
     def custom_llm_provider(self) -> str:
         """Get the vector store provider."""

@@ -317,5 +337,6 @@ class BaseRAGIngestion(ABC):
             status="failed",
             vector_store_id="",
             file_id=None,
+            error=str(e),
         )
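Note that the merge above only fills keys that are absent, so explicit values in the vector_store config win over the stored credential. A sketch of the setup this expects — the CredentialItem import path and field names follow litellm's credential model and should be treated as an assumption:

```python
import litellm
from litellm.types.utils import CredentialItem  # assumed import path

# Register a reusable credential once (e.g. at startup).
litellm.credential_list = [
    CredentialItem(
        credential_name="my-openai-creds",
        credential_values={"api_key": "sk-...", "api_base": "https://api.openai.com/v1"},
        credential_info={},
    )
]

# Any vector_store config referencing "my-openai-creds" now gets api_key/api_base
# merged in by _load_credentials_from_config(), without clobbering explicit values.
```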
@@ -11,8 +11,8 @@ Supports two modes:
 from __future__ import annotations

+import asyncio
 import json
-import time
 import uuid
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
@@ -37,6 +37,29 @@ def _get_int(value: Any, default: int) -> int:
     return int(value)


+def _normalize_principal_arn(caller_arn: str, account_id: str) -> str:
+    """
+    Normalize a caller ARN to the format required by OpenSearch data access policies.
+
+    OpenSearch Serverless data access policies require:
+    - IAM users: arn:aws:iam::account-id:user/user-name
+    - IAM roles: arn:aws:iam::account-id:role/role-name
+
+    But get_caller_identity() returns for assumed roles:
+    - arn:aws:sts::account-id:assumed-role/role-name/session-name
+
+    This function converts assumed-role ARNs to the proper IAM role ARN format.
+    """
+    if ":assumed-role/" in caller_arn:
+        # Extract role name from assumed-role ARN
+        # Format: arn:aws:sts::ACCOUNT:assumed-role/ROLE-NAME/SESSION-NAME
+        parts = caller_arn.split("/")
+        if len(parts) >= 2:
+            role_name = parts[1]
+            return f"arn:aws:iam::{account_id}:role/{role_name}"
+    return caller_arn
+
+
 class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
     """
     Bedrock Knowledge Base RAG ingestion.
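To make the normalization concrete, this is what the helper produces for an assumed-role caller versus a plain IAM user (values illustrative, behavior follows directly from the function above):

```python
# Assumed-role ARNs (what STS get_caller_identity returns) are rewritten
# to the IAM role ARN form that OpenSearch Serverless policies accept.
print(_normalize_principal_arn(
    "arn:aws:sts::123456789012:assumed-role/MyRole/my-session", "123456789012"
))
# -> arn:aws:iam::123456789012:role/MyRole

# Plain IAM user ARNs pass through unchanged.
print(_normalize_principal_arn(
    "arn:aws:iam::123456789012:user/alice", "123456789012"
))
# -> arn:aws:iam::123456789012:user/alice
```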
@@ -99,7 +122,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         # Track resources we create (for cleanup if needed)
         self._created_resources: Dict[str, Any] = {}

-    def _ensure_config_initialized(self):
+    async def _ensure_config_initialized(self):
         """Lazily initialize KB config - either detect from existing or create new."""
         if self._config_initialized:
             return
@@ -109,7 +132,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             self._auto_detect_config()
         else:
             # No KB provided - create everything from scratch
-            self._create_knowledge_base_infrastructure()
+            await self._create_knowledge_base_infrastructure()

         self._config_initialized = True
@@ -170,7 +193,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             )
         self.s3_bucket = self._s3_bucket

-    def _create_knowledge_base_infrastructure(self):
+    async def _create_knowledge_base_infrastructure(self):
         """Create all AWS resources needed for a new Knowledge Base."""
         verbose_logger.info("Creating new Bedrock Knowledge Base infrastructure...")
@@ -178,26 +201,28 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         unique_id = uuid.uuid4().hex[:8]
         kb_name = self.ingest_name or f"litellm-kb-{unique_id}"

-        # Get AWS account ID
+        # Get AWS account ID and caller ARN (for data access policy)
         sts = self._get_boto3_client("sts")
-        account_id = sts.get_caller_identity()["Account"]
+        caller_identity = sts.get_caller_identity()
+        account_id = caller_identity["Account"]
+        caller_arn = caller_identity["Arn"]

         # Step 1: Create S3 bucket (if not provided)
         self.s3_bucket = self._s3_bucket or self._create_s3_bucket(unique_id)

         # Step 2: Create OpenSearch Serverless collection
-        collection_name, collection_arn = self._create_opensearch_collection(
-            unique_id, account_id
+        collection_name, collection_arn = await self._create_opensearch_collection(
+            unique_id, account_id, caller_arn
         )

         # Step 3: Create OpenSearch index
-        self._create_opensearch_index(collection_name)
+        await self._create_opensearch_index(collection_name)

         # Step 4: Create IAM role for Bedrock
-        role_arn = self._create_bedrock_role(unique_id, account_id, collection_arn)
+        role_arn = await self._create_bedrock_role(unique_id, account_id, collection_arn)

         # Step 5: Create Knowledge Base
-        self.knowledge_base_id = self._create_knowledge_base(
+        self.knowledge_base_id = await self._create_knowledge_base(
             kb_name, role_arn, collection_arn
         )
@@ -228,8 +253,8 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         verbose_logger.info(f"Created S3 bucket: {bucket_name}")
         return bucket_name

-    def _create_opensearch_collection(
-        self, unique_id: str, account_id: str
+    async def _create_opensearch_collection(
+        self, unique_id: str, account_id: str, caller_arn: str
     ) -> Tuple[str, str]:
         """Create OpenSearch Serverless collection for vector storage."""
         oss = self._get_boto3_client("opensearchserverless")
@@ -258,7 +283,16 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             }]),
         )

-        # Create data access policy
+        # Create data access policy - include both root and actual caller ARN
+        # This ensures the credentials being used have access to the collection
+        # Normalize the caller ARN (convert assumed-role ARN to IAM role ARN if needed)
+        normalized_caller_arn = _normalize_principal_arn(caller_arn, account_id)
+        verbose_logger.debug(f"Caller ARN: {caller_arn}, Normalized: {normalized_caller_arn}")
+
+        principals = [f"arn:aws:iam::{account_id}:root", normalized_caller_arn]
+        # Deduplicate in case caller is root
+        principals = list(set(principals))
+
         oss.create_access_policy(
             name=f"{collection_name}-access",
             type="data",
@@ -267,7 +301,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
                 {"ResourceType": "index", "Resource": [f"index/{collection_name}/*"], "Permission": ["aoss:*"]},
                 {"ResourceType": "collection", "Resource": [f"collection/{collection_name}"], "Permission": ["aoss:*"]},
             ],
-            "Principal": [f"arn:aws:iam::{account_id}:root"],
+            "Principal": principals,
         }]),
     )
@@ -279,24 +313,29 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         collection_id = response["createCollectionDetail"]["id"]
         self._created_resources["opensearch_collection"] = collection_name

-        # Wait for collection to be active
+        # Wait for collection to be active (use asyncio.sleep to avoid blocking)
         verbose_logger.debug("Waiting for OpenSearch collection to be active...")
         for _ in range(60):  # 5 min timeout
             status_response = oss.batch_get_collection(ids=[collection_id])
             status = status_response["collectionDetails"][0]["status"]
             if status == "ACTIVE":
                 break
-            time.sleep(5)
+            await asyncio.sleep(5)
         else:
             raise TimeoutError("OpenSearch collection did not become active in time")

         collection_arn = status_response["collectionDetails"][0]["arn"]
         verbose_logger.info(f"Created OpenSearch collection: {collection_name}")

+        # Wait for data access policy to propagate before returning
+        # AWS recommends waiting 60+ seconds for policy propagation
+        verbose_logger.debug("Waiting for data access policy to propagate (60s)...")
+        await asyncio.sleep(60)
+
         return collection_name, collection_arn

-    def _create_opensearch_index(self, collection_name: str):
-        """Create vector index in OpenSearch collection."""
+    async def _create_opensearch_index(self, collection_name: str):
+        """Create vector index in OpenSearch collection with retry logic."""
         from opensearchpy import OpenSearch, RequestsHttpConnection
         from requests_aws4auth import AWS4Auth
@@ -348,10 +387,36 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             },
         }

-        client.indices.create(index=index_name, body=index_body)
-        verbose_logger.info(f"Created OpenSearch index: {index_name}")
+        # Retry logic for index creation - data access policy may take time to propagate
+        max_retries = 8
+        retry_delay = 20  # seconds
+        last_error = None
+
+        for attempt in range(max_retries):
+            try:
+                client.indices.create(index=index_name, body=index_body)
+                verbose_logger.info(f"Created OpenSearch index: {index_name}")
+                return
+            except Exception as e:
+                last_error = e
+                error_str = str(e)
+                if "authorization_exception" in error_str.lower() or "security_exception" in error_str.lower():
+                    verbose_logger.warning(
+                        f"OpenSearch index creation attempt {attempt + 1}/{max_retries} failed due to authorization. "
+                        f"Waiting {retry_delay}s for policy propagation..."
+                    )
+                    await asyncio.sleep(retry_delay)
+                else:
+                    # Non-auth error, raise immediately
+                    raise
+
+        # All retries exhausted
+        raise RuntimeError(
+            f"Failed to create OpenSearch index after {max_retries} attempts. "
+            f"Data access policy may not have propagated. Last error: {last_error}"
+        )

-    def _create_bedrock_role(
+    async def _create_bedrock_role(
         self, unique_id: str, account_id: str, collection_arn: str
     ) -> str:
         """Create IAM role for Bedrock KB."""
@@ -408,13 +473,13 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             PolicyDocument=json.dumps(permissions_policy),
         )

-        # Wait for role to propagate
-        time.sleep(10)
+        # Wait for role to propagate (use asyncio.sleep to avoid blocking)
+        await asyncio.sleep(10)

         verbose_logger.info(f"Created IAM role: {role_arn}")
         return role_arn

-    def _create_knowledge_base(
+    async def _create_knowledge_base(
         self, kb_name: str, role_arn: str, collection_arn: str
     ) -> str:
         """Create Bedrock Knowledge Base."""
@@ -447,14 +512,14 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         kb_id = response["knowledgeBase"]["knowledgeBaseId"]
         self._created_resources["knowledge_base"] = kb_id

-        # Wait for KB to be active
+        # Wait for KB to be active (use asyncio.sleep to avoid blocking)
         verbose_logger.debug("Waiting for Knowledge Base to be active...")
         for _ in range(30):
             kb_status = bedrock_agent.get_knowledge_base(knowledgeBaseId=kb_id)
             status = kb_status["knowledgeBase"]["status"]
             if status == "ACTIVE":
                 break
-            time.sleep(2)
+            await asyncio.sleep(2)
         else:
             raise TimeoutError("Knowledge Base did not become active in time")
@@ -555,7 +620,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
             Tuple of (knowledge_base_id, file_key)
         """
         # Auto-detect data source and S3 bucket if needed
-        self._ensure_config_initialized()
+        await self._ensure_config_initialized()

         if not file_content or not filename:
             verbose_logger.warning("No file content or filename provided for Bedrock ingestion")
@@ -587,10 +652,11 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
         job_id = ingestion_response["ingestionJob"]["ingestionJobId"]
         verbose_logger.info(f"Started ingestion job: {job_id}")

-        # Step 3: Wait for ingestion (optional)
+        # Step 3: Wait for ingestion (optional) - use asyncio.sleep to avoid blocking
         if self.wait_for_ingestion:
-            start_time = time.time()
-            while time.time() - start_time < self.ingestion_timeout:
+            import time as time_module
+            start_time = time_module.time()
+            while time_module.time() - start_time < self.ingestion_timeout:
                 job_status = bedrock_agent.get_ingestion_job(
                     knowledgeBaseId=self.knowledge_base_id,
                     dataSourceId=self.data_source_id,
@@ -610,7 +676,7 @@ class BedrockRAGIngestion(BaseRAGIngestion, BaseAWSLLM):
                     verbose_logger.error(f"Ingestion failed: {failure_reasons}")
                     break
                 elif status in ("STARTING", "IN_PROGRESS"):
-                    time.sleep(2)
+                    await asyncio.sleep(2)
                 else:
                     verbose_logger.warning(f"Unknown ingestion status: {status}")
                     break
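Putting the Bedrock path together: a hedged sketch of an ingest call that reuses an existing Knowledge Base and waits for the ingestion job, using the option names from the BedrockVectorStoreOptions diff further down (IDs and credential names are illustrative):

```python
import asyncio

import litellm


async def main():
    response = await litellm.aingest(
        ingest_options={
            "vector_store": {
                "custom_llm_provider": "bedrock",
                "vector_store_id": "KB_ID",  # existing Knowledge Base; omit to auto-create
                "litellm_credential_name": "my-aws-creds",
                "wait_for_ingestion": True,  # poll the ingestion job (default: False)
                "ingestion_timeout": 300,
            }
        },
        file_data=("doc.txt", b"Hello world", "text/plain"),
    )
    print(response["status"])


asyncio.run(main())
```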
@@ -78,6 +78,10 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
         vector_store_id = self.vector_store_config.get("vector_store_id")
         ttl_days = self.vector_store_config.get("ttl_days")

+        # Get credentials from vector_store_config (loaded from litellm_credential_name if provided)
+        api_key = self.vector_store_config.get("api_key")
+        api_base = self.vector_store_config.get("api_base")
+
         # Create vector store if not provided
         if not vector_store_id:
             expires_after = {"anchor": "last_active_at", "days": ttl_days} if ttl_days else None

@@ -85,6 +89,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
                 name=self.ingest_name or "litellm-rag-ingest",
                 custom_llm_provider="openai",
                 expires_after=expires_after,
+                api_key=api_key,
+                api_base=api_base,
             )
             vector_store_id = create_response.get("id")

@@ -96,6 +102,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
             file=(filename, file_content, content_type or "application/octet-stream"),
             purpose="assistants",
             custom_llm_provider="openai",
+            api_key=api_key,
+            api_base=api_base,
         )
         result_file_id = file_response.id

@@ -105,6 +113,8 @@ class OpenAIRAGIngestion(BaseRAGIngestion):
             file_id=result_file_id,
             custom_llm_provider="openai",
             chunking_strategy=cast(Optional[Dict[str, Any]], self.chunking_strategy),
+            api_key=api_key,
+            api_base=api_base,
         )

         return vector_store_id, result_file_id
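Because api_key/api_base now flow through the create, upload, and attach calls above, direct credentials also work without a stored credential name. A minimal sketch (key and base URL are placeholders):

```python
import litellm

response = litellm.ingest(  # sync variant; aingest is the async twin
    ingest_options={
        "vector_store": {
            "custom_llm_provider": "openai",
            "api_key": "sk-...",  # passed straight through to the OpenAI calls
            "api_base": "https://api.openai.com/v1",  # optional override
        }
    },
    file_data=("doc.txt", b"Hello world", "text/plain"),
)
print(response["vector_store_id"], response["file_id"])
```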
@@ -12,7 +12,7 @@ __all__ = ["ingest", "aingest"]
 import asyncio
 import contextvars
 from functools import partial
-from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Type, Union

 import httpx

@@ -84,7 +84,7 @@ async def _execute_ingest_pipeline(
     provider = vector_store_config.get("custom_llm_provider", "openai")

     # Get provider-specific ingestion class
-    ingestion_class = get_rag_ingestion_class(provider)
+    ingestion_class = get_ingestion_class(provider)

     # Create ingestion instance
     ingestion = ingestion_class(

@@ -127,7 +127,10 @@ async def aingest(
     ```python
     response = await litellm.aingest(
         ingest_options={
-            "vector_store": {"custom_llm_provider": "openai"}
+            "vector_store": {
+                "custom_llm_provider": "openai",
+                "litellm_credential_name": "my-openai-creds",  # optional
+            }
         },
         file_url="https://example.com/doc.pdf",
     )

@@ -193,7 +196,10 @@ def ingest(
     ```python
     response = litellm.ingest(
         ingest_options={
-            "vector_store": {"custom_llm_provider": "openai"}
+            "vector_store": {
+                "custom_llm_provider": "openai",
+                "litellm_credential_name": "my-openai-creds",  # optional
+            }
         },
         file_data=("doc.txt", b"Hello world", "text/plain"),
     )
@@ -4,7 +4,7 @@ RAG utility functions.
 Provides provider configuration utilities similar to ProviderConfigManager.
 """

-from typing import TYPE_CHECKING, Optional, Type
+from typing import TYPE_CHECKING, Type

 if TYPE_CHECKING:
     from litellm.rag.ingestion.base_ingestion import BaseRAGIngestion
@@ -41,12 +41,20 @@ class OpenAIVectorStoreOptions(TypedDict, total=False):

     Example (use existing):
         {"custom_llm_provider": "openai", "vector_store_id": "vs_xxx"}
+
+    Example (with credentials):
+        {"custom_llm_provider": "openai", "litellm_credential_name": "my-openai-creds"}
     """

     custom_llm_provider: Literal["openai"]
     vector_store_id: Optional[str]  # Existing VS ID (auto-creates if not provided)
     ttl_days: Optional[int]  # Time-to-live in days for indexed content

+    # Credentials (loaded from litellm.credential_list if litellm_credential_name is provided)
+    litellm_credential_name: Optional[str]  # Credential name to load from litellm.credential_list
+    api_key: Optional[str]  # Direct API key (alternative to litellm_credential_name)
+    api_base: Optional[str]  # Direct API base (alternative to litellm_credential_name)
+

 class BedrockVectorStoreOptions(TypedDict, total=False):
     """

@@ -58,6 +66,9 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
     Example (use existing KB):
         {"custom_llm_provider": "bedrock", "vector_store_id": "KB_ID"}

+    Example (with credentials):
+        {"custom_llm_provider": "bedrock", "litellm_credential_name": "my-aws-creds"}
+
     Auto-creation creates: S3 bucket, OpenSearch Serverless collection,
     IAM role, Knowledge Base, and Data Source.
     """

@@ -73,6 +84,9 @@ class BedrockVectorStoreOptions(TypedDict, total=False):
     wait_for_ingestion: Optional[bool]  # Wait for completion (default: False - returns immediately)
     ingestion_timeout: Optional[int]  # Timeout in seconds if wait_for_ingestion=True (default: 300)

+    # Credentials (loaded from litellm.credential_list if litellm_credential_name is provided)
+    litellm_credential_name: Optional[str]  # Credential name to load from litellm.credential_list
+
     # AWS auth (uses BaseAWSLLM)
     aws_access_key_id: Optional[str]
     aws_secret_access_key: Optional[str]

@@ -160,6 +174,7 @@ class RAGIngestResponse(TypedDict, total=False):
     status: Literal["completed", "in_progress", "failed"]
     vector_store_id: str  # The vector store ID (created or existing)
     file_id: Optional[str]  # The file ID in the vector store
+    error: Optional[str]  # Error message if status is "failed"
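The new error field gives callers a concrete failure signal to act on. A small sketch of consuming it, based on the RAGIngestResponse shape above:

```python
import litellm

response = litellm.ingest(
    ingest_options={"vector_store": {"custom_llm_provider": "openai"}},
    file_data=("doc.txt", b"Hello world", "text/plain"),
)

if response["status"] == "failed":
    # error carries the stringified exception from the ingestion pipeline
    raise RuntimeError(f"RAG ingestion failed: {response['error']}")
```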