Mirror of https://github.com/BerriAI/litellm.git (synced 2025-12-06 11:33:26 +08:00)
(feat) Vector Stores: support Vertex AI Search API as vector store through LiteLLM (#15781)
* feat(vector_stores/): initial commit adding Vertex AI Search API support as a new LiteLLM vector store provider
* feat(vector_store/): use vector store id for vertex ai search api
* fix: transformation.py cleanup
* fix: implement abstract function
* fix: fix linting error
* fix: main.py fix check
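In practice, the new provider is reached through LiteLLM's existing vector store SDK surface by passing custom_llm_provider="vertex_ai/search_api". A minimal sketch based on the mocked test added in this commit (the engine ID, project, and location below are placeholder values):

    import litellm

    # Placeholder values; vector_store_id maps to the Discovery Engine app/engine ID.
    response = litellm.vector_stores.search(
        query="what is LiteLLM?",
        vector_store_id="my-search-engine-id",
        custom_llm_provider="vertex_ai/search_api",
        vertex_project="my-gcp-project",
        vertex_location="us-central1",
    )

    # Results come back in LiteLLM's standard vector store search format.
    for result in response["data"]:
        print(result["score"], result["filename"], result["content"][0]["text"])

An async variant, litellm.vector_stores.asearch, takes the same arguments (see the test file at the end of this diff).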
.gitignore (vendored)

@@ -99,3 +99,4 @@ litellm/proxy/to_delete_loadtest_work/*
 update_model_cost_map.py
 tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
 litellm/proxy/_experimental/out/guardrails/index.html
+scripts/test_vertex_ai_search.py
litellm/llms/custom_httpx/llm_http_handler.py

@@ -44,13 +44,8 @@ from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig, OCRResponse
 from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
-from litellm.llms.base_llm.search.transformation import (
-    BaseSearchConfig,
-    SearchResponse,
-)
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.search.transformation import BaseSearchConfig, SearchResponse
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -1313,8 +1308,10 @@ class BaseLLMHTTPHandler:
 
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
 
         data = transformed_result.data
 
         ## LOGGING
@@ -1377,8 +1374,10 @@ class BaseLLMHTTPHandler:
 
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
 
         data = transformed_result.data
 
         ## LOGGING
@@ -1549,7 +1548,6 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
         )
 
-
     def search(
         self,
         query: Union[str, List[str]],
litellm/llms/vertex_ai/vector_stores/__init__.py

@@ -1,3 +1,4 @@
-from .transformation import VertexVectorStoreConfig
+from .rag_api.transformation import VertexVectorStoreConfig
+from .search_api.transformation import VertexSearchAPIVectorStoreConfig
 
-__all__ = ["VertexVectorStoreConfig"]
+__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]
litellm/llms/vertex_ai/vector_stores/search_api/transformation.py (new file)

@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import httpx

from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
    VectorStoreCreateOptionalRequestParams,
    VectorStoreCreateResponse,
    VectorStoreResultContent,
    VectorStoreSearchOptionalRequestParams,
    VectorStoreSearchResponse,
    VectorStoreSearchResult,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
    """
    Configuration for Vertex AI Search API Vector Store

    This implementation uses the Vertex AI Search API for vector store operations.
    """

    def __init__(self):
        super().__init__()

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and set up authentication for the Vertex AI Search API
        """
        litellm_params = litellm_params or GenericLiteLLMParams()

        # Get credentials and project info
        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
        vertex_project = self.get_vertex_ai_project(dict(litellm_params))

        # Get access token using the base class method
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        headers.update(
            {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            }
        )

        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the base endpoint for the Vertex AI Search API
        """
        vertex_location = self.get_vertex_ai_location(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        engine_id = litellm_params.get("vector_store_id")
        collection_id = (
            litellm_params.get("vertex_collection_id") or "default_collection"
        )
        if api_base:
            return api_base.rstrip("/")

        # Vertex AI Search API endpoint for search
        return (
            f"https://discoveryengine.googleapis.com/v1/"
            f"projects/{vertex_project}/locations/{vertex_location}/"
            f"collections/{collection_id}/engines/{engine_id}/servingConfigs/default_config"
        )

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform a search request for the Vertex AI Search API
        """
        # Convert query to string if it's a list
        if isinstance(query, list):
            query = " ".join(query)

        # Vertex AI Search API endpoint for retrieving results
        url = f"{api_base}:search"

        # Build the request body for Vertex AI Search API
        request_body = {"query": query, "pageSize": 10}

        #########################################################
        # Update logging object with details of the request
        #########################################################
        litellm_logging_obj.model_call_details["query"] = query

        return url, request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform a Vertex AI Search API response into the standard vector store search response.

        Handles the format from the Discovery Engine Search API, which returns:
        {
            "results": [
                {
                    "id": "...",
                    "document": {
                        "derivedStructData": {
                            "title": "...",
                            "link": "...",
                            "snippets": [...]
                        }
                    }
                }
            ]
        }
        """
        try:
            response_json = response.json()

            # Extract results from Vertex AI Search API response
            results = response_json.get("results", [])

            # Transform results to standard format
            search_results: List[VectorStoreSearchResult] = []
            for result in results:
                document = result.get("document", {})
                derived_data = document.get("derivedStructData", {})

                # Extract text content from snippets
                snippets = derived_data.get("snippets", [])
                text_content = ""

                if snippets:
                    # Combine all snippets into one text
                    text_parts = [
                        snippet.get("snippet", snippet.get("htmlSnippet", ""))
                        for snippet in snippets
                    ]
                    text_content = " ".join(text_parts)

                # If no snippets, use title as fallback
                if not text_content:
                    text_content = derived_data.get("title", "")

                content = [
                    VectorStoreResultContent(
                        text=text_content,
                        type="text",
                    )
                ]

                # Extract file/document information
                document_link = derived_data.get("link", "")
                document_title = derived_data.get("title", "")
                document_id = result.get("id", "")

                # Use link as file_id if available, otherwise use document ID
                file_id = document_link if document_link else document_id
                filename = document_title if document_title else "Unknown Document"

                # Build attributes with available metadata
                attributes = {
                    "document_id": document_id,
                }

                if document_link:
                    attributes["link"] = document_link
                if document_title:
                    attributes["title"] = document_title

                # Add display link if available
                display_link = derived_data.get("displayLink", "")
                if display_link:
                    attributes["displayLink"] = display_link

                # Add formatted URL if available
                formatted_url = derived_data.get("formattedUrl", "")
                if formatted_url:
                    attributes["formattedUrl"] = formatted_url

                # Note: the Search API doesn't provide explicit scores in the response,
                # so the result's position is used as an implicit, decreasing score.
                score = 1.0 / float(len(search_results) + 1)

                result_obj = VectorStoreSearchResult(
                    score=score,
                    content=content,
                    file_id=file_id,
                    filename=filename,
                    attributes=attributes,
                )
                search_results.append(result_obj)

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                search_query=litellm_logging_obj.model_call_details.get("query", ""),
                data=search_results,
            )

        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        raise NotImplementedError

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        raise NotImplementedError
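For reference, a small sketch of the endpoint this config builds when no api_base override is given. The project, location, and engine ID below are placeholder values, and it is assumed the vertex helpers read vertex_project / vertex_location from the params dict, as the mocked test in this commit does:

    from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
        VertexSearchAPIVectorStoreConfig,
    )

    config = VertexSearchAPIVectorStoreConfig()

    # Placeholder params; vector_store_id maps to the Discovery Engine engine ID.
    url = config.get_complete_url(
        api_base=None,
        litellm_params={
            "vector_store_id": "my-engine-id",
            "vertex_project": "my-gcp-project",
            "vertex_location": "us-central1",
        },
    )
    print(url)
    # Expected shape:
    # https://discoveryengine.googleapis.com/v1/projects/my-gcp-project/locations/us-central1/
    #   collections/default_collection/engines/my-engine-id/servingConfigs/default_config
    # The search request is then POSTed to f"{url}:search" with {"query": ..., "pageSize": 10}.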
Proxy config (YAML)

@@ -3,6 +3,16 @@ model_list:
     litellm_params:
       model: bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0
 
+vector_store_registry:
+  - vector_store_name: "vertex-ai-litellm-website-knowledgebase"
+    litellm_params:
+      vector_store_id: "test-litellm-app_1761094730750"
+      custom_llm_provider: "vertex_ai/search_api"
+      vertex_project: "test-litellm-app"
+      vertex_location: "us-central1"
+      vector_store_description: "Vertex AI vector store for the Litellm website knowledgebase"
+      vector_store_metadata:
+        source: "https://www.litellm.com/docs"
 mcp_servers:
   github_mcp:
     url: "https://api.githubcopilot.com/mcp"
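Once registered, the store is addressed by its ID through the proxy. The proxy route itself is not part of this diff; assuming the OpenAI-compatible vector store search endpoint is exposed, a request could look roughly like this (base URL, API key, and route are assumptions, not shown in this commit):

    import httpx

    # Hypothetical call against a locally running LiteLLM proxy; the
    # /v1/vector_stores/{id}/search route is assumed, not shown in this diff.
    resp = httpx.post(
        "http://localhost:4000/v1/vector_stores/test-litellm-app_1761094730750/search",
        headers={"Authorization": "Bearer sk-1234"},
        json={"query": "What is LiteLLM?"},
    )
    print(resp.json()["data"][0]["content"][0]["text"])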
litellm/types/router.py

@@ -163,9 +163,6 @@ class CredentialLiteLLMParams(BaseModel):
     watsonx_region_name: Optional[str] = None
 
 
-
-
-
 class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)

@@ -210,6 +207,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     s3_bucket_name: Optional[str] = None
     gcs_bucket_name: Optional[str] = None
 
+    # Vector Store Params
+    vector_store_id: Optional[str] = None
+
     def __init__(
         self,
         custom_llm_provider: Optional[str] = None,
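With the new field, the vector store ID travels on the generic params object rather than only as a function kwarg, so provider configs can read it from litellm_params (as get_complete_url does above). A minimal sketch mirroring how litellm/vector_stores/main.py now builds it:

    from litellm.types.router import GenericLiteLLMParams

    # main.py forwards the caller's vector_store_id onto the params object.
    params = GenericLiteLLMParams(vector_store_id="test-litellm-app_1761094730750")
    print(params.vector_store_id)  # -> "test-litellm-app_1761094730750"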
litellm/utils.py

@@ -145,9 +145,7 @@ from litellm.llms.base_llm.google_genai.transformation import (
 )
 from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
 from litellm.llms.base_llm.search.transformation import BaseSearchConfig
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.mistral.ocr.transformation import MistralOCRConfig

@@ -914,6 +912,7 @@ def _get_wrapper_timeout(
 
     return timeout
 
+
 def check_coroutine(value) -> bool:
     return get_coroutine_checker().is_async_callable(value)
 

@@ -993,9 +992,7 @@ def post_call_processing(
             ].message.content  # type: ignore
             if model_response is not None:
                 ### POST-CALL RULES ###
-                rules_obj.post_call_rules(
-                    input=model_response, model=model
-                )
+                rules_obj.post_call_rules(input=model_response, model=model)
                 ### JSON SCHEMA VALIDATION ###
                 if litellm.enable_json_schema_validation is True:
                     try:

@@ -1011,9 +1008,9 @@ def post_call_processing(
                                 optional_params["response_format"],
                                 dict,
                             )
-                            and optional_params[
-                                "response_format"
-                            ].get("json_schema")
+                            and optional_params["response_format"].get(
+                                "json_schema"
+                            )
                             is not None
                         ):
                             json_response_format = optional_params[

@@ -1041,9 +1038,7 @@ def post_call_processing(
                 if (
                     optional_params is not None
                     and "response_format" in optional_params
-                    and isinstance(
-                        optional_params["response_format"], dict
-                    )
+                    and isinstance(optional_params["response_format"], dict)
                     and "type" in optional_params["response_format"]
                     and optional_params["response_format"]["type"]
                     == "json_object"

@@ -1077,7 +1072,6 @@ def post_call_processing(
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
 
-
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first

@@ -5000,7 +4994,9 @@ def _get_model_info_helper(  # noqa: PLR0915
             tpm=_model_info.get("tpm", None),
             rpm=_model_info.get("rpm", None),
             ocr_cost_per_page=_model_info.get("ocr_cost_per_page", None),
-            annotation_cost_per_page=_model_info.get("annotation_cost_per_page", None),
+            annotation_cost_per_page=_model_info.get(
+                "annotation_cost_per_page", None
+            ),
         )
     except Exception as e:
         verbose_logger.debug(f"Error getting model info: {e}")

@@ -7226,6 +7222,7 @@ class ProviderConfigManager:
             from litellm.llms.sagemaker.embedding.transformation import (
                 SagemakerEmbeddingConfig,
             )
+
             return SagemakerEmbeddingConfig.get_model_config(model)
         return None
 

@@ -7461,6 +7458,7 @@ class ProviderConfigManager:
     @staticmethod
     def get_provider_vector_stores_config(
         provider: LlmProviders,
+        api_type: Optional[str] = None,
     ) -> Optional[BaseVectorStoreConfig]:
         """
         v2 vector store config, use this for new vector store integrations

@@ -7478,11 +7476,18 @@ class ProviderConfigManager:
 
             return AzureOpenAIVectorStoreConfig()
         elif litellm.LlmProviders.VERTEX_AI == provider:
-            from litellm.llms.vertex_ai.vector_stores.transformation import (
-                VertexVectorStoreConfig,
-            )
-
-            return VertexVectorStoreConfig()
+            if api_type == "rag_api" or api_type is None:  # default to rag_api
+                from litellm.llms.vertex_ai.vector_stores.rag_api.transformation import (
+                    VertexVectorStoreConfig,
+                )
+
+                return VertexVectorStoreConfig()
+            elif api_type == "search_api":
+                from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
+                    VertexSearchAPIVectorStoreConfig,
+                )
+
+                return VertexSearchAPIVectorStoreConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             from litellm.llms.bedrock.vector_stores.transformation import (
                 BedrockVectorStoreConfig,

@@ -7640,12 +7645,8 @@ class ProviderConfigManager:
         from litellm.llms.parallel_ai.search.transformation import (
             ParallelAISearchConfig,
         )
-        from litellm.llms.perplexity.search.transformation import (
-            PerplexitySearchConfig,
-        )
-        from litellm.llms.tavily.search.transformation import (
-            TavilySearchConfig,
-        )
+        from litellm.llms.perplexity.search.transformation import PerplexitySearchConfig
+        from litellm.llms.tavily.search.transformation import TavilySearchConfig
 
         PROVIDER_TO_CONFIG_MAP = {
             SearchProviders.PERPLEXITY: PerplexitySearchConfig,
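The api_type argument is what routes vertex_ai vector store requests to either the existing RAG API config or the new Search API config. A short sketch of the dispatch added here:

    import litellm
    from litellm.utils import ProviderConfigManager

    # api_type=None (or "rag_api") keeps the existing RAG API behaviour ...
    rag_config = ProviderConfigManager.get_provider_vector_stores_config(
        provider=litellm.LlmProviders.VERTEX_AI,
    )

    # ... while "search_api" selects the new Discovery Engine based config.
    search_config = ProviderConfigManager.get_provider_vector_stores_config(
        provider=litellm.LlmProviders.VERTEX_AI,
        api_type="search_api",
    )
    print(type(rag_config).__name__, type(search_config).__name__)
    # VertexVectorStoreConfig VertexSearchAPIVectorStoreConfig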
litellm/vector_stores/main.py

@@ -1,6 +1,7 @@
 """
 LiteLLM SDK Functions for Creating and Searching Vector Stores
 """
+
 import asyncio
 import contextvars
 from functools import partial

@@ -10,6 +11,7 @@ import httpx
 
 import litellm
 from litellm.constants import request_timeout
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.types.router import GenericLiteLLMParams

@@ -42,12 +44,12 @@ def mock_vector_store_search_response(
             content=[
                 VectorStoreResultContent(
                     text="This is a sample search result from the vector store.",
-                    type="text"
+                    type="text",
                 )
-            ]
+            ],
         )
     ]
 
     return VectorStoreSearchResponse(
         object="vector_store.search_results.page",
         search_query="sample query",

@@ -79,7 +81,7 @@ def mock_vector_store_create_response(
         last_active_at=None,
         metadata=None,
     )
 
     return mock_response
 
 

@@ -166,14 +168,14 @@ def create(
 ) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
     """
     Create a vector store.
 
     Args:
         name: The name of the vector store.
         file_ids: A list of File IDs that the vector store should use.
         expires_after: The expiration policy for the vector store.
         chunking_strategy: The chunking strategy used to chunk the file(s).
         metadata: Set of 16 key-value pairs that can be attached to an object.
 
     Returns:
         VectorStoreCreateResponse containing the created vector store details.
     """

@@ -198,9 +200,18 @@ def create(
         if custom_llm_provider is None:
             custom_llm_provider = "openai"
 
+        api_type, custom_llm_provider, _, _ = get_llm_provider(
+            model=custom_llm_provider,
+            custom_llm_provider=None,
+            litellm_params=None,
+        )
+
         # get provider config - using vector store custom logger for now
-        vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
-            provider=litellm.LlmProviders(custom_llm_provider),
+        vector_store_provider_config = (
+            ProviderConfigManager.get_provider_vector_stores_config(
+                provider=litellm.LlmProviders(custom_llm_provider),
+                api_type=api_type,
+            )
         )
 
         if vector_store_provider_config is None:

@@ -209,7 +220,7 @@ def create(
             )
 
         local_vars.update(kwargs)
 
         # Get VectorStoreCreateOptionalRequestParams with only valid parameters
         vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
             VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(

@@ -242,7 +253,7 @@ def create(
             _is_async=_is_async,
             client=kwargs.get("client"),
         )
 
         return response
     except Exception as e:
         raise litellm.exception_type(

@@ -340,7 +351,7 @@ def search(
 ) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
     """
     Search a vector store for relevant chunks based on a query and file attributes filter.
 
     Args:
         vector_store_id: The ID of the vector store to search.
         query: A query string or array for the search.

@@ -348,7 +359,7 @@ def search(
         max_num_results: Maximum number of results to return (1-50, default 10).
         ranking_options: Optional ranking options for search.
         rewrite_query: Whether to rewrite the natural language query for vector search.
 
     Returns:
         VectorStoreSearchResponse containing the search results.
     """

@@ -375,7 +386,7 @@ def search(
         pass
 
     # get llm provider logic
-    litellm_params = GenericLiteLLMParams(**kwargs)
+    litellm_params = GenericLiteLLMParams(vector_store_id=vector_store_id, **kwargs)
 
     ## MOCK RESPONSE LOGIC
     if litellm_params.mock_response and isinstance(

@@ -390,9 +401,22 @@ def search(
         if custom_llm_provider is None:
             custom_llm_provider = "openai"
 
+        if "/" in custom_llm_provider:
+            api_type, custom_llm_provider, _, _ = get_llm_provider(
+                model=custom_llm_provider,
+                custom_llm_provider=None,
+                litellm_params=None,
+            )
+        else:
+            api_type = None
+            custom_llm_provider = custom_llm_provider
+
         # get provider config - using vector store custom logger for now
-        vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
-            provider=litellm.LlmProviders(custom_llm_provider),
+        vector_store_provider_config = (
+            ProviderConfigManager.get_provider_vector_stores_config(
+                provider=litellm.LlmProviders(custom_llm_provider),
+                api_type=api_type,
+            )
         )
 
         if vector_store_provider_config is None:

@@ -401,7 +425,7 @@ def search(
             )
 
         local_vars.update(kwargs)
 
         # Get VectorStoreSearchOptionalRequestParams with only valid parameters
         vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = (
             VectorStoreRequestUtils.get_requested_vector_store_search_optional_param(

@@ -438,7 +462,7 @@ def search(
             _is_async=_is_async,
             client=kwargs.get("client"),
         )
 
         return response
     except Exception as e:
         raise litellm.exception_type(

@@ -447,4 +471,4 @@ def search(
             original_exception=e,
             completion_kwargs=local_vars,
             extra_kwargs=kwargs,
         )
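The provider string is split the same way model strings are: get_llm_provider treats the text before the slash as the provider and returns the remainder as its first element, which search() and create() then pass along as api_type. A sketch of the expected behaviour, mirroring the call added above (the unused return values are the API key and base):

    from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

    # "vertex_ai/search_api" -> api_type "search_api", provider "vertex_ai",
    # matching the `if "/" in custom_llm_provider` branch in search().
    api_type, custom_llm_provider, _api_key, _api_base = get_llm_provider(
        model="vertex_ai/search_api",
        custom_llm_provider=None,
        litellm_params=None,
    )
    print(api_type, custom_llm_provider)  # search_api vertex_ai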
New test file (Vertex AI Search API vector store, mocked)

@@ -0,0 +1,194 @@
"""
Test for Vertex AI Search API Vector Store with mocked responses
"""

import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import litellm


# Mock response from actual Vertex AI Search API
MOCK_VERTEX_SEARCH_RESPONSE = {
    "results": [
        {
            "id": "0",
            "document": {
                "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/0",
                "id": "0",
                "derivedStructData": {
                    "htmlTitle": "<b>LiteLLM</b> - Getting Started | <b>liteLLM</b>",
                    "snippets": [
                        {
                            "htmlSnippet": "https://github.com/BerriAI/<b>litellm</b>.",
                            "snippet": "https://github.com/BerriAI/litellm.",
                        }
                    ],
                    "title": "LiteLLM - Getting Started | liteLLM",
                    "link": "https://docs.litellm.ai/docs/",
                    "displayLink": "docs.litellm.ai",
                },
            },
        },
        {
            "id": "1",
            "document": {
                "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/1",
                "id": "1",
                "derivedStructData": {
                    "title": "Using Vector Stores (Knowledge Bases) | liteLLM",
                    "link": "https://docs.litellm.ai/docs/completion/knowledgebase",
                    "snippets": [
                        {
                            "snippet": "LiteLLM integrates with vector stores, allowing your models to access your organization's data for more accurate and contextually relevant responses."
                        }
                    ],
                },
            },
        },
    ],
    "totalSize": 299,
    "attributionToken": "mock_token",
    "summary": {},
}


class TestVertexAISearchAPIVectorStore:
    """Test Vertex AI Search API Vector Store with mocked responses"""

    @pytest.mark.asyncio
    async def test_basic_search_with_mock(self):
        """Test basic vector search with mocked backend response"""

        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)

        # Mock the access token method to avoid real authentication
        with patch(
            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
        ) as mock_auth:
            mock_auth.return_value = ("mock_token", "test-vector-store-db")

            with patch(
                "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
                new_callable=AsyncMock,
            ) as mock_post:
                mock_post.return_value = mock_response

                # Make the search request
                response = await litellm.vector_stores.asearch(
                    query="what is LiteLLM?",
                    vector_store_id="test-litellm-app_1761094730750",
                    custom_llm_provider="vertex_ai/search_api",
                    vertex_project="test-vector-store-db",
                    vertex_location="us-central1",
                )

                print("Response:", json.dumps(response, indent=2, default=str))

                # Validate the response structure (LiteLLM standard format)
                assert response is not None
                assert response["object"] == "vector_store.search_results.page"
                assert "data" in response
                assert len(response["data"]) > 0
                assert "search_query" in response

                # Validate first result
                first_result = response["data"][0]
                assert "score" in first_result
                assert "content" in first_result
                assert "file_id" in first_result
                assert "filename" in first_result
                assert "attributes" in first_result

                # Validate content structure
                assert len(first_result["content"]) > 0
                assert first_result["content"][0]["type"] == "text"
                assert "text" in first_result["content"][0]

                # Verify the API was called
                mock_post.assert_called_once()

                # Verify the URL format
                call_args = mock_post.call_args
                url = call_args[1]["url"] if "url" in call_args[1] else call_args[0][0]
                assert "discoveryengine.googleapis.com" in url
                assert "test-vector-store-db" in url
                assert "test-litellm-app_1761094730750" in url

    def test_basic_search_sync_with_mock(self):
        """Test basic vector search (sync) with mocked backend response"""

        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)

        # Mock the access token method to avoid real authentication
        with patch(
            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
        ) as mock_auth:
            mock_auth.return_value = ("mock_token", "test-vector-store-db")

            with patch(
                "litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
            ) as mock_post:
                mock_post.return_value = mock_response

                # Make the search request
                response = litellm.vector_stores.search(
                    query="what is LiteLLM?",
                    vector_store_id="test-litellm-app_1761094730750",
                    custom_llm_provider="vertex_ai/search_api",
                    vertex_project="test-vector-store-db",
                    vertex_location="us-central1",
                )

                print("Response:", json.dumps(response, indent=2, default=str))

                # Validate the response structure (LiteLLM standard format)
                assert response is not None
                assert response["object"] == "vector_store.search_results.page"
                assert "data" in response
                assert len(response["data"]) > 0
                assert "search_query" in response

                # Validate first result structure
                first_result = response["data"][0]
                assert "score" in first_result
                assert "content" in first_result
                assert "file_id" in first_result
                assert "filename" in first_result
                assert "attributes" in first_result

                # Validate content structure
                assert len(first_result["content"]) > 0
                assert first_result["content"][0]["type"] == "text"
                assert "text" in first_result["content"][0]

                # Validate attributes
                assert "document_id" in first_result["attributes"]
                assert "link" in first_result["attributes"]
                assert "title" in first_result["attributes"]

                # Verify the API was called
                mock_post.assert_called_once()


if __name__ == "__main__":
    # Run tests
    import asyncio

    test = TestVertexAISearchAPIVectorStore()

    print("Running async test...")
    asyncio.run(test.test_basic_search_with_mock())

    print("\nRunning sync test...")
    test.test_basic_search_sync_with_mock()

    print("\n✅ All tests passed!")