(feat) Vector Stores: support Vertex AI Search API as vector store through LiteLLM (#15781)

* feat(vector_stores/): initial commit adding Vertex AI Search API support for litellm

new vector store provider

* feat(vector_store/): use vector store id for vertex ai search api

* fix: transformation.py

cleanup

* fix: implement abstract function

* fix: fix linting error

* fix: main.py

fix check
Author: Krish Dholakia
Date: 2025-10-22 18:56:36 -07:00
Committed by: GitHub
Parent: 3e4b5ef3a5
Commit: 573306f3cd
10 changed files with 529 additions and 59 deletions
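For a quick sense of what this commit enables, the new provider is reachable through the existing vector stores SDK surface. A minimal sketch, using the sample project/engine IDs that appear in the test and example config below (swap in your own Vertex AI Search engine):

import litellm

# "vertex_ai/search_api" routes to the new VertexSearchAPIVectorStoreConfig;
# the plain "vertex_ai" provider still defaults to the existing RAG API config.
response = litellm.vector_stores.search(
    query="what is LiteLLM?",
    vector_store_id="test-litellm-app_1761094730750",  # Vertex AI Search engine ID
    custom_llm_provider="vertex_ai/search_api",
    vertex_project="test-litellm-app",
    vertex_location="us-central1",
)
print(response["data"][0]["content"][0]["text"])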

.gitignore

@@ -99,3 +99,4 @@ litellm/proxy/to_delete_loadtest_work/*
 update_model_cost_map.py
 tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
 litellm/proxy/_experimental/out/guardrails/index.html
+scripts/test_vertex_ai_search.py


@@ -44,13 +44,8 @@ from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig, OCRResponse
 from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
-from litellm.llms.base_llm.search.transformation import (
-    BaseSearchConfig,
-    SearchResponse,
-)
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.search.transformation import BaseSearchConfig, SearchResponse
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -1313,8 +1308,10 @@ class BaseLLMHTTPHandler:
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
         data = transformed_result.data
         ## LOGGING
@@ -1377,8 +1374,10 @@ class BaseLLMHTTPHandler:
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
         data = transformed_result.data
         ## LOGGING
@@ -1549,7 +1548,6 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
         )
     def search(
         self,
         query: Union[str, List[str]],


@@ -1,3 +1,4 @@
-from .transformation import VertexVectorStoreConfig
+from .rag_api.transformation import VertexVectorStoreConfig
+from .search_api.transformation import VertexSearchAPIVectorStoreConfig
 
-__all__ = ["VertexVectorStoreConfig"]
+__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]


@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import httpx

from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.router import GenericLiteLLMParams
from litellm.types.vector_stores import (
    VectorStoreCreateOptionalRequestParams,
    VectorStoreCreateResponse,
    VectorStoreResultContent,
    VectorStoreSearchOptionalRequestParams,
    VectorStoreSearchResponse,
    VectorStoreSearchResult,
)

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any
class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
    """
    Configuration for Vertex AI Search API Vector Store

    This implementation uses the Vertex AI Search API for vector store operations.
    """

    def __init__(self):
        super().__init__()

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and set up authentication for the Vertex AI Search API
        """
        litellm_params = litellm_params or GenericLiteLLMParams()

        # Get credentials and project info
        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
        vertex_project = self.get_vertex_ai_project(dict(litellm_params))

        # Get access token using the base class method
        access_token, project_id = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        headers.update(
            {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            }
        )
        return headers
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the base endpoint for the Vertex AI Search API
        """
        vertex_location = self.get_vertex_ai_location(litellm_params)
        vertex_project = self.get_vertex_ai_project(litellm_params)
        engine_id = litellm_params.get("vector_store_id")
        collection_id = (
            litellm_params.get("vertex_collection_id") or "default_collection"
        )

        if api_base:
            return api_base.rstrip("/")

        # Vertex AI Search API endpoint for search
        return (
            f"https://discoveryengine.googleapis.com/v1/"
            f"projects/{vertex_project}/locations/{vertex_location}/"
            f"collections/{collection_id}/engines/{engine_id}/servingConfigs/default_config"
        )
    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Transform the search request for the Vertex AI Search API
        """
        # Convert query to string if it's a list
        if isinstance(query, list):
            query = " ".join(query)

        # Vertex AI Search API endpoint for search
        url = f"{api_base}:search"

        # Build the request body for the Vertex AI Search API
        request_body = {"query": query, "pageSize": 10}

        #########################################################
        # Update logging object with details of the request
        #########################################################
        litellm_logging_obj.model_call_details["query"] = query

        return url, request_body
    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform Vertex AI Search API response to standard vector store search response

        Handles the format from the Discovery Engine Search API which returns:
        {
            "results": [
                {
                    "id": "...",
                    "document": {
                        "derivedStructData": {
                            "title": "...",
                            "link": "...",
                            "snippets": [...]
                        }
                    }
                }
            ]
        }
        """
        try:
            response_json = response.json()

            # Extract results from the Vertex AI Search API response
            results = response_json.get("results", [])

            # Transform results to the standard format
            search_results: List[VectorStoreSearchResult] = []
            for result in results:
                document = result.get("document", {})
                derived_data = document.get("derivedStructData", {})

                # Extract text content from snippets
                snippets = derived_data.get("snippets", [])
                text_content = ""
                if snippets:
                    # Combine all snippets into one text
                    text_parts = [
                        snippet.get("snippet", snippet.get("htmlSnippet", ""))
                        for snippet in snippets
                    ]
                    text_content = " ".join(text_parts)

                # If no snippets, use title as fallback
                if not text_content:
                    text_content = derived_data.get("title", "")

                content = [
                    VectorStoreResultContent(
                        text=text_content,
                        type="text",
                    )
                ]

                # Extract file/document information
                document_link = derived_data.get("link", "")
                document_title = derived_data.get("title", "")
                document_id = result.get("id", "")

                # Use link as file_id if available, otherwise use the document ID
                file_id = document_link if document_link else document_id
                filename = document_title if document_title else "Unknown Document"

                # Build attributes with available metadata
                attributes = {
                    "document_id": document_id,
                }
                if document_link:
                    attributes["link"] = document_link
                if document_title:
                    attributes["title"] = document_title

                # Add display link if available
                display_link = derived_data.get("displayLink", "")
                if display_link:
                    attributes["displayLink"] = display_link

                # Add formatted URL if available
                formatted_url = derived_data.get("formattedUrl", "")
                if formatted_url:
                    attributes["formattedUrl"] = formatted_url

                # Note: the Search API doesn't provide explicit scores in the response;
                # use the position/rank as an implicit, decreasing score
                score = 1.0 / float(len(search_results) + 1)

                result_obj = VectorStoreSearchResult(
                    score=score,
                    content=content,
                    file_id=file_id,
                    filename=filename,
                    attributes=attributes,
                )
                search_results.append(result_obj)

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                search_query=litellm_logging_obj.model_call_details.get("query", ""),
                data=search_results,
            )
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        raise NotImplementedError

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        raise NotImplementedError
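For orientation, here is a rough sketch of the request that get_complete_url and transform_search_vector_store_request assemble, using the sample project, location, and engine IDs that appear in the example config and test elsewhere in this commit (illustrative values, not defaults):

# Illustrative only: the search URL and body built for the sample IDs.
url = (
    "https://discoveryengine.googleapis.com/v1/"
    "projects/test-litellm-app/locations/us-central1/"
    "collections/default_collection/engines/test-litellm-app_1761094730750/"
    "servingConfigs/default_config:search"
)
request_body = {"query": "what is LiteLLM?", "pageSize": 10}

transform_search_vector_store_response then flattens each result's derivedStructData (snippets, title, link) into the standard VectorStoreSearchResult shape.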


@@ -3,6 +3,16 @@ model_list:
     litellm_params:
       model: bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0
 
+vector_store_registry:
+  - vector_store_name: "vertex-ai-litellm-website-knowledgebase"
+    litellm_params:
+      vector_store_id: "test-litellm-app_1761094730750"
+      custom_llm_provider: "vertex_ai/search_api"
+      vertex_project: "test-litellm-app"
+      vertex_location: "us-central1"
+      vector_store_description: "Vertex AI vector store for the Litellm website knowledgebase"
+      vector_store_metadata:
+        source: "https://www.litellm.com/docs"
 mcp_servers:
   github_mcp:
     url: "https://api.githubcopilot.com/mcp"


@@ -163,9 +163,6 @@ class CredentialLiteLLMParams(BaseModel):
     watsonx_region_name: Optional[str] = None
 class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)
@@ -210,6 +207,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     s3_bucket_name: Optional[str] = None
     gcs_bucket_name: Optional[str] = None
+    # Vector Store Params
+    vector_store_id: Optional[str] = None
     def __init__(
         self,
         custom_llm_provider: Optional[str] = None,


@@ -145,9 +145,7 @@ from litellm.llms.base_llm.google_genai.transformation import (
 )
 from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
 from litellm.llms.base_llm.search.transformation import BaseSearchConfig
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
@@ -914,6 +912,7 @@ def _get_wrapper_timeout(
     return timeout
 def check_coroutine(value) -> bool:
     return get_coroutine_checker().is_async_callable(value)
@@ -993,9 +992,7 @@ def post_call_processing(
                 ].message.content  # type: ignore
             if model_response is not None:
                 ### POST-CALL RULES ###
-                rules_obj.post_call_rules(
-                    input=model_response, model=model
-                )
+                rules_obj.post_call_rules(input=model_response, model=model)
                 ### JSON SCHEMA VALIDATION ###
                 if litellm.enable_json_schema_validation is True:
                     try:
@@ -1011,9 +1008,9 @@ def post_call_processing(
                             optional_params["response_format"],
                             dict,
                         )
-                        and optional_params[
-                            "response_format"
-                        ].get("json_schema")
+                        and optional_params["response_format"].get(
+                            "json_schema"
+                        )
                         is not None
                     ):
                         json_response_format = optional_params[
@@ -1041,9 +1038,7 @@ def post_call_processing(
                     if (
                         optional_params is not None
                         and "response_format" in optional_params
-                        and isinstance(
-                            optional_params["response_format"], dict
-                        )
+                        and isinstance(optional_params["response_format"], dict)
                         and "type" in optional_params["response_format"]
                         and optional_params["response_format"]["type"]
                         == "json_object"
@@ -1077,7 +1072,6 @@ def post_call_processing(
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -5000,7 +4994,9 @@ def _get_model_info_helper(  # noqa: PLR0915
             tpm=_model_info.get("tpm", None),
             rpm=_model_info.get("rpm", None),
             ocr_cost_per_page=_model_info.get("ocr_cost_per_page", None),
-            annotation_cost_per_page=_model_info.get("annotation_cost_per_page", None),
+            annotation_cost_per_page=_model_info.get(
+                "annotation_cost_per_page", None
+            ),
         )
     except Exception as e:
         verbose_logger.debug(f"Error getting model info: {e}")
@@ -7226,6 +7222,7 @@ class ProviderConfigManager:
             from litellm.llms.sagemaker.embedding.transformation import (
                 SagemakerEmbeddingConfig,
             )
             return SagemakerEmbeddingConfig.get_model_config(model)
         return None
@@ -7461,6 +7458,7 @@ class ProviderConfigManager:
     @staticmethod
     def get_provider_vector_stores_config(
         provider: LlmProviders,
+        api_type: Optional[str] = None,
     ) -> Optional[BaseVectorStoreConfig]:
         """
         v2 vector store config, use this for new vector store integrations
@@ -7478,11 +7476,18 @@ class ProviderConfigManager:
             return AzureOpenAIVectorStoreConfig()
         elif litellm.LlmProviders.VERTEX_AI == provider:
-            from litellm.llms.vertex_ai.vector_stores.transformation import (
-                VertexVectorStoreConfig,
-            )
-            return VertexVectorStoreConfig()
+            if api_type == "rag_api" or api_type is None:  # default to rag_api
+                from litellm.llms.vertex_ai.vector_stores.rag_api.transformation import (
+                    VertexVectorStoreConfig,
+                )
+                return VertexVectorStoreConfig()
+            elif api_type == "search_api":
+                from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
+                    VertexSearchAPIVectorStoreConfig,
+                )
+                return VertexSearchAPIVectorStoreConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             from litellm.llms.bedrock.vector_stores.transformation import (
                 BedrockVectorStoreConfig,
@@ -7640,12 +7645,8 @@ class ProviderConfigManager:
         from litellm.llms.parallel_ai.search.transformation import (
             ParallelAISearchConfig,
         )
-        from litellm.llms.perplexity.search.transformation import (
-            PerplexitySearchConfig,
-        )
-        from litellm.llms.tavily.search.transformation import (
-            TavilySearchConfig,
-        )
+        from litellm.llms.perplexity.search.transformation import PerplexitySearchConfig
+        from litellm.llms.tavily.search.transformation import TavilySearchConfig
         PROVIDER_TO_CONFIG_MAP = {
             SearchProviders.PERPLEXITY: PerplexitySearchConfig,
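The dispatch above is what makes the "vertex_ai/search_api" provider string work end to end: vector_stores.search() (next file) splits the string with get_llm_provider into the provider ("vertex_ai") and an api_type ("search_api"), and ProviderConfigManager maps that api_type to the new config class. A minimal sketch of the lookup, assuming ProviderConfigManager is importable from litellm.utils:

import litellm
from litellm.utils import ProviderConfigManager  # assumed import path for this sketch

# "vertex_ai/search_api" -> provider "vertex_ai" + api_type "search_api"
config = ProviderConfigManager.get_provider_vector_stores_config(
    provider=litellm.LlmProviders.VERTEX_AI,
    api_type="search_api",
)
print(type(config).__name__)  # VertexSearchAPIVectorStoreConfig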


@@ -1,6 +1,7 @@
 """
 LiteLLM SDK Functions for Creating and Searching Vector Stores
 """
 import asyncio
 import contextvars
 from functools import partial
@@ -10,6 +11,7 @@ import httpx
 import litellm
 from litellm.constants import request_timeout
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.types.router import GenericLiteLLMParams
@@ -42,12 +44,12 @@ def mock_vector_store_search_response(
             content=[
                 VectorStoreResultContent(
                     text="This is a sample search result from the vector store.",
-                    type="text"
+                    type="text",
                 )
-            ]
+            ],
         )
     ]
     return VectorStoreSearchResponse(
         object="vector_store.search_results.page",
         search_query="sample query",
@@ -79,7 +81,7 @@ def mock_vector_store_create_response(
         last_active_at=None,
         metadata=None,
     )
     return mock_response
@@ -166,14 +168,14 @@ def create(
 ) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
     """
     Create a vector store.
     Args:
         name: The name of the vector store.
         file_ids: A list of File IDs that the vector store should use.
         expires_after: The expiration policy for the vector store.
         chunking_strategy: The chunking strategy used to chunk the file(s).
         metadata: Set of 16 key-value pairs that can be attached to an object.
     Returns:
         VectorStoreCreateResponse containing the created vector store details.
     """
@@ -198,9 +200,18 @@ def create(
         if custom_llm_provider is None:
             custom_llm_provider = "openai"
+        api_type, custom_llm_provider, _, _ = get_llm_provider(
+            model=custom_llm_provider,
+            custom_llm_provider=None,
+            litellm_params=None,
+        )
         # get provider config - using vector store custom logger for now
-        vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
-            provider=litellm.LlmProviders(custom_llm_provider),
+        vector_store_provider_config = (
+            ProviderConfigManager.get_provider_vector_stores_config(
+                provider=litellm.LlmProviders(custom_llm_provider),
+                api_type=api_type,
+            )
         )
@@ -209,7 +220,7 @@ def create(
             )
         local_vars.update(kwargs)
         # Get VectorStoreCreateOptionalRequestParams with only valid parameters
         vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
             VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(
@@ -242,7 +253,7 @@ def create(
             _is_async=_is_async,
             client=kwargs.get("client"),
         )
         return response
     except Exception as e:
         raise litellm.exception_type(
@@ -340,7 +351,7 @@ def search(
 ) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
     """
     Search a vector store for relevant chunks based on a query and file attributes filter.
     Args:
         vector_store_id: The ID of the vector store to search.
         query: A query string or array for the search.
@@ -348,7 +359,7 @@ def search(
         max_num_results: Maximum number of results to return (1-50, default 10).
         ranking_options: Optional ranking options for search.
         rewrite_query: Whether to rewrite the natural language query for vector search.
     Returns:
         VectorStoreSearchResponse containing the search results.
     """
@@ -375,7 +386,7 @@ def search(
             pass
         # get llm provider logic
-        litellm_params = GenericLiteLLMParams(**kwargs)
+        litellm_params = GenericLiteLLMParams(vector_store_id=vector_store_id, **kwargs)
         ## MOCK RESPONSE LOGIC
         if litellm_params.mock_response and isinstance(
@@ -390,9 +401,22 @@ def search(
         if custom_llm_provider is None:
             custom_llm_provider = "openai"
+        if "/" in custom_llm_provider:
+            api_type, custom_llm_provider, _, _ = get_llm_provider(
+                model=custom_llm_provider,
+                custom_llm_provider=None,
+                litellm_params=None,
+            )
+        else:
+            api_type = None
+            custom_llm_provider = custom_llm_provider
         # get provider config - using vector store custom logger for now
-        vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
-            provider=litellm.LlmProviders(custom_llm_provider),
+        vector_store_provider_config = (
+            ProviderConfigManager.get_provider_vector_stores_config(
+                provider=litellm.LlmProviders(custom_llm_provider),
+                api_type=api_type,
+            )
         )
@@ -401,7 +425,7 @@ def search(
             )
         local_vars.update(kwargs)
         # Get VectorStoreSearchOptionalRequestParams with only valid parameters
         vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = (
             VectorStoreRequestUtils.get_requested_vector_store_search_optional_param(
@@ -438,7 +462,7 @@ def search(
             _is_async=_is_async,
             client=kwargs.get("client"),
         )
         return response
     except Exception as e:
         raise litellm.exception_type(
@@ -447,4 +471,4 @@ def search(
             original_exception=e,
             completion_kwargs=local_vars,
             extra_kwargs=kwargs,
         )


@@ -0,0 +1,194 @@
"""
Test for Vertex AI Search API Vector Store with mocked responses
"""

import json

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

import litellm

# Mock response from actual Vertex AI Search API
MOCK_VERTEX_SEARCH_RESPONSE = {
    "results": [
        {
            "id": "0",
            "document": {
                "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/0",
                "id": "0",
                "derivedStructData": {
                    "htmlTitle": "<b>LiteLLM</b> - Getting Started | <b>liteLLM</b>",
                    "snippets": [
                        {
                            "htmlSnippet": "https://github.com/BerriAI/<b>litellm</b>.",
                            "snippet": "https://github.com/BerriAI/litellm.",
                        }
                    ],
                    "title": "LiteLLM - Getting Started | liteLLM",
                    "link": "https://docs.litellm.ai/docs/",
                    "displayLink": "docs.litellm.ai",
                },
            },
        },
        {
            "id": "1",
            "document": {
                "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/1",
                "id": "1",
                "derivedStructData": {
                    "title": "Using Vector Stores (Knowledge Bases) | liteLLM",
                    "link": "https://docs.litellm.ai/docs/completion/knowledgebase",
                    "snippets": [
                        {
                            "snippet": "LiteLLM integrates with vector stores, allowing your models to access your organization's data for more accurate and contextually relevant responses."
                        }
                    ],
                },
            },
        },
    ],
    "totalSize": 299,
    "attributionToken": "mock_token",
    "summary": {},
}


class TestVertexAISearchAPIVectorStore:
    """Test Vertex AI Search API Vector Store with mocked responses"""

    @pytest.mark.asyncio
    async def test_basic_search_with_mock(self):
        """Test basic vector search with mocked backend response"""
        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)

        # Mock the access token method to avoid real authentication
        with patch(
            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
        ) as mock_auth:
            mock_auth.return_value = ("mock_token", "test-vector-store-db")

            with patch(
                "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
                new_callable=AsyncMock,
            ) as mock_post:
                mock_post.return_value = mock_response

                # Make the search request
                response = await litellm.vector_stores.asearch(
                    query="what is LiteLLM?",
                    vector_store_id="test-litellm-app_1761094730750",
                    custom_llm_provider="vertex_ai/search_api",
                    vertex_project="test-vector-store-db",
                    vertex_location="us-central1",
                )

                print("Response:", json.dumps(response, indent=2, default=str))

                # Validate the response structure (LiteLLM standard format)
                assert response is not None
                assert response["object"] == "vector_store.search_results.page"
                assert "data" in response
                assert len(response["data"]) > 0
                assert "search_query" in response

                # Validate first result
                first_result = response["data"][0]
                assert "score" in first_result
                assert "content" in first_result
                assert "file_id" in first_result
                assert "filename" in first_result
                assert "attributes" in first_result

                # Validate content structure
                assert len(first_result["content"]) > 0
                assert first_result["content"][0]["type"] == "text"
                assert "text" in first_result["content"][0]

                # Verify the API was called
                mock_post.assert_called_once()

                # Verify the URL format
                call_args = mock_post.call_args
                url = call_args[1]["url"] if "url" in call_args[1] else call_args[0][0]
                assert "discoveryengine.googleapis.com" in url
                assert "test-vector-store-db" in url
                assert "test-litellm-app_1761094730750" in url

    def test_basic_search_sync_with_mock(self):
        """Test basic vector search (sync) with mocked backend response"""
        # Mock the HTTP response
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)

        # Mock the access token method to avoid real authentication
        with patch(
            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
        ) as mock_auth:
            mock_auth.return_value = ("mock_token", "test-vector-store-db")

            with patch(
                "litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
            ) as mock_post:
                mock_post.return_value = mock_response

                # Make the search request
                response = litellm.vector_stores.search(
                    query="what is LiteLLM?",
                    vector_store_id="test-litellm-app_1761094730750",
                    custom_llm_provider="vertex_ai/search_api",
                    vertex_project="test-vector-store-db",
                    vertex_location="us-central1",
                )

                print("Response:", json.dumps(response, indent=2, default=str))

                # Validate the response structure (LiteLLM standard format)
                assert response is not None
                assert response["object"] == "vector_store.search_results.page"
                assert "data" in response
                assert len(response["data"]) > 0
                assert "search_query" in response

                # Validate first result structure
                first_result = response["data"][0]
                assert "score" in first_result
                assert "content" in first_result
                assert "file_id" in first_result
                assert "filename" in first_result
                assert "attributes" in first_result

                # Validate content structure
                assert len(first_result["content"]) > 0
                assert first_result["content"][0]["type"] == "text"
                assert "text" in first_result["content"][0]

                # Validate attributes
                assert "document_id" in first_result["attributes"]
                assert "link" in first_result["attributes"]
                assert "title" in first_result["attributes"]

                # Verify the API was called
                mock_post.assert_called_once()


if __name__ == "__main__":
    # Run tests
    import asyncio

    test = TestVertexAISearchAPIVectorStore()

    print("Running async test...")
    asyncio.run(test.test_basic_search_with_mock())

    print("\nRunning sync test...")
    test.test_basic_search_sync_with_mock()

    print("\n✅ All tests passed!")