diff --git a/.gitignore b/.gitignore
index e1045032d4..aa973201fd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,3 +99,4 @@ litellm/proxy/to_delete_loadtest_work/*
update_model_cost_map.py
tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
litellm/proxy/_experimental/out/guardrails/index.html
+scripts/test_vertex_ai_search.py
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index f99310bbf0..56e5f0c948 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -44,13 +44,8 @@ from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig, OCRResponse
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
-from litellm.llms.base_llm.search.transformation import (
- BaseSearchConfig,
- SearchResponse,
-)
-from litellm.llms.base_llm.text_to_speech.transformation import (
- BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.search.transformation import BaseSearchConfig, SearchResponse
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@@ -1313,8 +1308,10 @@ class BaseLLMHTTPHandler:
# Data is always a dict for Mistral OCR format
if not isinstance(transformed_result.data, dict):
- raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
-
+ raise ValueError(
+ f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+ )
+
data = transformed_result.data
## LOGGING
@@ -1377,8 +1374,10 @@ class BaseLLMHTTPHandler:
# Data is always a dict for Mistral OCR format
if not isinstance(transformed_result.data, dict):
- raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
-
+ raise ValueError(
+ f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+ )
+
data = transformed_result.data
## LOGGING
@@ -1549,7 +1548,6 @@ class BaseLLMHTTPHandler:
logging_obj=logging_obj,
)
-
def search(
self,
query: Union[str, List[str]],
diff --git a/litellm/llms/vertex_ai/vector_stores/__init__.py b/litellm/llms/vertex_ai/vector_stores/__init__.py
index f3c210a973..98da2c581a 100644
--- a/litellm/llms/vertex_ai/vector_stores/__init__.py
+++ b/litellm/llms/vertex_ai/vector_stores/__init__.py
@@ -1,3 +1,4 @@
-from .transformation import VertexVectorStoreConfig
+from .rag_api.transformation import VertexVectorStoreConfig
+from .search_api.transformation import VertexSearchAPIVectorStoreConfig
-__all__ = ["VertexVectorStoreConfig"]
\ No newline at end of file
+__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]
diff --git a/litellm/llms/vertex_ai/vector_stores/transformation.py b/litellm/llms/vertex_ai/vector_stores/rag_api/transformation.py
similarity index 100%
rename from litellm/llms/vertex_ai/vector_stores/transformation.py
rename to litellm/llms/vertex_ai/vector_stores/rag_api/transformation.py
diff --git a/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py b/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py
new file mode 100644
index 0000000000..3d13c99840
--- /dev/null
+++ b/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py
@@ -0,0 +1,241 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import httpx
+
+from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+from litellm.types.router import GenericLiteLLMParams
+from litellm.types.vector_stores import (
+ VectorStoreCreateOptionalRequestParams,
+ VectorStoreCreateResponse,
+ VectorStoreResultContent,
+ VectorStoreSearchOptionalRequestParams,
+ VectorStoreSearchResponse,
+ VectorStoreSearchResult,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
+ """
+ Configuration for Vertex AI Search API Vector Store
+
+ This implementation uses the Vertex AI Search API for vector store operations.
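+
+    Example usage (illustrative; the project, location, and engine IDs below are
+    placeholders, not real resources):
+
+        import litellm
+
+        response = litellm.vector_stores.search(
+            query="what is LiteLLM?",
+            vector_store_id="my-engine-id",  # Discovery Engine engine ID
+            custom_llm_provider="vertex_ai/search_api",
+            vertex_project="my-gcp-project",
+            vertex_location="us-central1",
+        )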
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def validate_environment(
+ self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
+ ) -> dict:
+ """
+        Validate and set up authentication for the Vertex AI Search API
+ """
+ litellm_params = litellm_params or GenericLiteLLMParams()
+
+ # Get credentials and project info
+ vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
+ vertex_project = self.get_vertex_ai_project(dict(litellm_params))
+
+ # Get access token using the base class method
+ access_token, project_id = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ headers.update(
+ {
+ "Authorization": f"Bearer {access_token}",
+ "Content-Type": "application/json",
+ }
+ )
+
+ return headers
+
+ def get_complete_url(
+ self,
+ api_base: Optional[str],
+ litellm_params: dict,
+ ) -> str:
+ """
+        Get the base endpoint for the Vertex AI Search API
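+
+        With placeholder project/engine values and no api_base override, the
+        resulting endpoint looks like:
+
+            https://discoveryengine.googleapis.com/v1/projects/my-gcp-project/
+            locations/us-central1/collections/default_collection/
+            engines/my-engine-id/servingConfigs/default_config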
+ """
+ vertex_location = self.get_vertex_ai_location(litellm_params)
+ vertex_project = self.get_vertex_ai_project(litellm_params)
+ engine_id = litellm_params.get("vector_store_id")
+ collection_id = (
+ litellm_params.get("vertex_collection_id") or "default_collection"
+ )
+ if api_base:
+ return api_base.rstrip("/")
+
+ # Vertex AI Search API endpoint for search
+ return (
+ f"https://discoveryengine.googleapis.com/v1/"
+ f"projects/{vertex_project}/locations/{vertex_location}/"
+ f"collections/{collection_id}/engines/{engine_id}/servingConfigs/default_config"
+ )
+
+ def transform_search_vector_store_request(
+ self,
+ vector_store_id: str,
+ query: Union[str, List[str]],
+ vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
+ api_base: str,
+ litellm_logging_obj: LiteLLMLoggingObj,
+ litellm_params: dict,
+ ) -> Tuple[str, Dict[str, Any]]:
+ """
+        Transform the search request for the Vertex AI Search (Discovery Engine) API
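+
+        For example, query="what is LiteLLM?" produces (api_base shown as a placeholder):
+
+            url  = "<api_base>:search"
+            body = {"query": "what is LiteLLM?", "pageSize": 10}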
+ """
+ # Convert query to string if it's a list
+ if isinstance(query, list):
+ query = " ".join(query)
+
+        # Vertex AI Search API endpoint: POST :search on the serving config
+ url = f"{api_base}:search"
+
+        # Build the request body for the Vertex AI Search API
+        # (pageSize is currently fixed at 10 results per query)
+ request_body = {"query": query, "pageSize": 10}
+
+ #########################################################
+ # Update logging object with details of the request
+ #########################################################
+ litellm_logging_obj.model_call_details["query"] = query
+
+ return url, request_body
+
+ def transform_search_vector_store_response(
+ self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
+ ) -> VectorStoreSearchResponse:
+ """
+ Transform Vertex AI Search API response to standard vector store search response
+
+ Handles the format from Discovery Engine Search API which returns:
+ {
+ "results": [
+ {
+ "id": "...",
+ "document": {
+ "derivedStructData": {
+ "title": "...",
+ "link": "...",
+ "snippets": [...]
+ }
+ }
+ }
+ ]
+ }
+ """
+ try:
+ response_json = response.json()
+
+ # Extract results from Vertex AI Search API response
+ results = response_json.get("results", [])
+
+ # Transform results to standard format
+ search_results: List[VectorStoreSearchResult] = []
+ for result in results:
+ document = result.get("document", {})
+ derived_data = document.get("derivedStructData", {})
+
+ # Extract text content from snippets
+ snippets = derived_data.get("snippets", [])
+ text_content = ""
+
+ if snippets:
+ # Combine all snippets into one text
+ text_parts = [
+ snippet.get("snippet", snippet.get("htmlSnippet", ""))
+ for snippet in snippets
+ ]
+ text_content = " ".join(text_parts)
+
+ # If no snippets, use title as fallback
+ if not text_content:
+ text_content = derived_data.get("title", "")
+
+ content = [
+ VectorStoreResultContent(
+ text=text_content,
+ type="text",
+ )
+ ]
+
+ # Extract file/document information
+ document_link = derived_data.get("link", "")
+ document_title = derived_data.get("title", "")
+ document_id = result.get("id", "")
+
+ # Use link as file_id if available, otherwise use document ID
+ file_id = document_link if document_link else document_id
+ filename = document_title if document_title else "Unknown Document"
+
+ # Build attributes with available metadata
+ attributes = {
+ "document_id": document_id,
+ }
+
+ if document_link:
+ attributes["link"] = document_link
+ if document_title:
+ attributes["title"] = document_title
+
+ # Add display link if available
+ display_link = derived_data.get("displayLink", "")
+ if display_link:
+ attributes["displayLink"] = display_link
+
+ # Add formatted URL if available
+ formatted_url = derived_data.get("formattedUrl", "")
+ if formatted_url:
+ attributes["formattedUrl"] = formatted_url
+
+ # Note: Search API doesn't provide explicit scores in the response
+ # You can use the position/rank as an implicit score
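+            # e.g. 1st result -> 1.0, 2nd -> 0.5, 3rd -> ~0.33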
+            score = 1.0 / float(len(search_results) + 1)  # decreases with position
+
+ result_obj = VectorStoreSearchResult(
+ score=score,
+ content=content,
+ file_id=file_id,
+ filename=filename,
+ attributes=attributes,
+ )
+ search_results.append(result_obj)
+
+ return VectorStoreSearchResponse(
+ object="vector_store.search_results.page",
+ search_query=litellm_logging_obj.model_call_details.get("query", ""),
+ data=search_results,
+ )
+
+ except Exception as e:
+ raise self.get_error_class(
+ error_message=str(e),
+ status_code=response.status_code,
+ headers=response.headers,
+ )
+
+ def transform_create_vector_store_request(
+ self,
+ vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
+ api_base: str,
+ ) -> Tuple[str, Dict]:
+ raise NotImplementedError
+
+ def transform_create_vector_store_response(
+ self, response: httpx.Response
+ ) -> VectorStoreCreateResponse:
+ raise NotImplementedError
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 2a2e6c9940..7725446d10 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,6 +3,16 @@ model_list:
litellm_params:
model: bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0
+vector_store_registry:
+ - vector_store_name: "vertex-ai-litellm-website-knowledgebase"
+ litellm_params:
+ vector_store_id: "test-litellm-app_1761094730750"
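+      # "vertex_ai/search_api" routes to the Vertex AI Search (Discovery Engine) backend;
+      # plain "vertex_ai" defaults to the RAG Engine API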
+ custom_llm_provider: "vertex_ai/search_api"
+ vertex_project: "test-litellm-app"
+ vertex_location: "us-central1"
+ vector_store_description: "Vertex AI vector store for the Litellm website knowledgebase"
+ vector_store_metadata:
+ source: "https://www.litellm.com/docs"
mcp_servers:
github_mcp:
url: "https://api.githubcopilot.com/mcp"
diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8790d0d558..354e97d14d 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -163,9 +163,6 @@ class CredentialLiteLLMParams(BaseModel):
watsonx_region_name: Optional[str] = None
-
-
-
class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
"""
LiteLLM Params without 'model' arg (used across completion / assistants api)
@@ -210,6 +207,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
s3_bucket_name: Optional[str] = None
gcs_bucket_name: Optional[str] = None
+ # Vector Store Params
+ vector_store_id: Optional[str] = None
+
def __init__(
self,
custom_llm_provider: Optional[str] = None,
diff --git a/litellm/utils.py b/litellm/utils.py
index d21c09c69f..f1e56954fe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -145,9 +145,7 @@ from litellm.llms.base_llm.google_genai.transformation import (
)
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
from litellm.llms.base_llm.search.transformation import BaseSearchConfig
-from litellm.llms.base_llm.text_to_speech.transformation import (
- BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
from litellm.llms.bedrock.common_utils import BedrockModelInfo
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
@@ -914,6 +912,7 @@ def _get_wrapper_timeout(
return timeout
+
def check_coroutine(value) -> bool:
return get_coroutine_checker().is_async_callable(value)
@@ -993,9 +992,7 @@ def post_call_processing(
].message.content # type: ignore
if model_response is not None:
### POST-CALL RULES ###
- rules_obj.post_call_rules(
- input=model_response, model=model
- )
+ rules_obj.post_call_rules(input=model_response, model=model)
### JSON SCHEMA VALIDATION ###
if litellm.enable_json_schema_validation is True:
try:
@@ -1011,9 +1008,9 @@ def post_call_processing(
optional_params["response_format"],
dict,
)
- and optional_params[
- "response_format"
- ].get("json_schema")
+ and optional_params["response_format"].get(
+ "json_schema"
+ )
is not None
):
json_response_format = optional_params[
@@ -1041,9 +1038,7 @@ def post_call_processing(
if (
optional_params is not None
and "response_format" in optional_params
- and isinstance(
- optional_params["response_format"], dict
- )
+ and isinstance(optional_params["response_format"], dict)
and "type" in optional_params["response_format"]
and optional_params["response_format"]["type"]
== "json_object"
@@ -1077,7 +1072,6 @@ def post_call_processing(
def client(original_function): # noqa: PLR0915
rules_obj = Rules()
-
@wraps(original_function)
def wrapper(*args, **kwargs): # noqa: PLR0915
# DO NOT MOVE THIS. It always needs to run first
@@ -5000,7 +4994,9 @@ def _get_model_info_helper( # noqa: PLR0915
tpm=_model_info.get("tpm", None),
rpm=_model_info.get("rpm", None),
ocr_cost_per_page=_model_info.get("ocr_cost_per_page", None),
- annotation_cost_per_page=_model_info.get("annotation_cost_per_page", None),
+ annotation_cost_per_page=_model_info.get(
+ "annotation_cost_per_page", None
+ ),
)
except Exception as e:
verbose_logger.debug(f"Error getting model info: {e}")
@@ -7226,6 +7222,7 @@ class ProviderConfigManager:
from litellm.llms.sagemaker.embedding.transformation import (
SagemakerEmbeddingConfig,
)
+
return SagemakerEmbeddingConfig.get_model_config(model)
return None
@@ -7461,6 +7458,7 @@ class ProviderConfigManager:
@staticmethod
def get_provider_vector_stores_config(
provider: LlmProviders,
+ api_type: Optional[str] = None,
) -> Optional[BaseVectorStoreConfig]:
"""
v2 vector store config, use this for new vector store integrations
@@ -7478,11 +7476,18 @@ class ProviderConfigManager:
return AzureOpenAIVectorStoreConfig()
elif litellm.LlmProviders.VERTEX_AI == provider:
- from litellm.llms.vertex_ai.vector_stores.transformation import (
- VertexVectorStoreConfig,
- )
+ if api_type == "rag_api" or api_type is None: # default to rag_api
+ from litellm.llms.vertex_ai.vector_stores.rag_api.transformation import (
+ VertexVectorStoreConfig,
+ )
- return VertexVectorStoreConfig()
+                return VertexVectorStoreConfig()
+ elif api_type == "search_api":
+ from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
+ VertexSearchAPIVectorStoreConfig,
+ )
+
+                return VertexSearchAPIVectorStoreConfig()
elif litellm.LlmProviders.BEDROCK == provider:
from litellm.llms.bedrock.vector_stores.transformation import (
BedrockVectorStoreConfig,
@@ -7640,12 +7645,8 @@ class ProviderConfigManager:
from litellm.llms.parallel_ai.search.transformation import (
ParallelAISearchConfig,
)
- from litellm.llms.perplexity.search.transformation import (
- PerplexitySearchConfig,
- )
- from litellm.llms.tavily.search.transformation import (
- TavilySearchConfig,
- )
+ from litellm.llms.perplexity.search.transformation import PerplexitySearchConfig
+ from litellm.llms.tavily.search.transformation import TavilySearchConfig
PROVIDER_TO_CONFIG_MAP = {
SearchProviders.PERPLEXITY: PerplexitySearchConfig,
diff --git a/litellm/vector_stores/main.py b/litellm/vector_stores/main.py
index 01e06e1306..3cbbfea180 100644
--- a/litellm/vector_stores/main.py
+++ b/litellm/vector_stores/main.py
@@ -1,6 +1,7 @@
"""
LiteLLM SDK Functions for Creating and Searching Vector Stores
"""
+
import asyncio
import contextvars
from functools import partial
@@ -10,6 +11,7 @@ import httpx
import litellm
from litellm.constants import request_timeout
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.types.router import GenericLiteLLMParams
@@ -42,12 +44,12 @@ def mock_vector_store_search_response(
content=[
VectorStoreResultContent(
text="This is a sample search result from the vector store.",
- type="text"
+ type="text",
)
- ]
+ ],
)
]
-
+
return VectorStoreSearchResponse(
object="vector_store.search_results.page",
search_query="sample query",
@@ -79,7 +81,7 @@ def mock_vector_store_create_response(
last_active_at=None,
metadata=None,
)
-
+
return mock_response
@@ -166,14 +168,14 @@ def create(
) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
"""
Create a vector store.
-
+
Args:
name: The name of the vector store.
file_ids: A list of File IDs that the vector store should use.
expires_after: The expiration policy for the vector store.
chunking_strategy: The chunking strategy used to chunk the file(s).
metadata: Set of 16 key-value pairs that can be attached to an object.
-
+
Returns:
VectorStoreCreateResponse containing the created vector store details.
"""
@@ -198,9 +200,18 @@ def create(
if custom_llm_provider is None:
custom_llm_provider = "openai"
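+    # Split an optional api_type suffix, e.g. "vertex_ai/search_api" -> api_type="search_api"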
+ api_type, custom_llm_provider, _, _ = get_llm_provider(
+ model=custom_llm_provider,
+ custom_llm_provider=None,
+ litellm_params=None,
+ )
+
# get provider config - using vector store custom logger for now
- vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
- provider=litellm.LlmProviders(custom_llm_provider),
+ vector_store_provider_config = (
+ ProviderConfigManager.get_provider_vector_stores_config(
+ provider=litellm.LlmProviders(custom_llm_provider),
+ api_type=api_type,
+ )
)
if vector_store_provider_config is None:
@@ -209,7 +220,7 @@ def create(
)
local_vars.update(kwargs)
-
+
# Get VectorStoreCreateOptionalRequestParams with only valid parameters
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(
@@ -242,7 +253,7 @@ def create(
_is_async=_is_async,
client=kwargs.get("client"),
)
-
+
return response
except Exception as e:
raise litellm.exception_type(
@@ -340,7 +351,7 @@ def search(
) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
"""
Search a vector store for relevant chunks based on a query and file attributes filter.
-
+
Args:
vector_store_id: The ID of the vector store to search.
query: A query string or array for the search.
@@ -348,7 +359,7 @@ def search(
max_num_results: Maximum number of results to return (1-50, default 10).
ranking_options: Optional ranking options for search.
rewrite_query: Whether to rewrite the natural language query for vector search.
-
+
Returns:
VectorStoreSearchResponse containing the search results.
"""
@@ -375,7 +386,7 @@ def search(
pass
# get llm provider logic
- litellm_params = GenericLiteLLMParams(**kwargs)
+ litellm_params = GenericLiteLLMParams(vector_store_id=vector_store_id, **kwargs)
## MOCK RESPONSE LOGIC
if litellm_params.mock_response and isinstance(
@@ -390,9 +401,22 @@ def search(
if custom_llm_provider is None:
custom_llm_provider = "openai"
+ if "/" in custom_llm_provider:
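+            # e.g. "vertex_ai/search_api" -> api_type="search_api", custom_llm_provider="vertex_ai"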
+ api_type, custom_llm_provider, _, _ = get_llm_provider(
+ model=custom_llm_provider,
+ custom_llm_provider=None,
+ litellm_params=None,
+ )
+        else:
+            # provider string has no api_type suffix (e.g. plain "vertex_ai")
+            api_type = None
+
# get provider config - using vector store custom logger for now
- vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
- provider=litellm.LlmProviders(custom_llm_provider),
+ vector_store_provider_config = (
+ ProviderConfigManager.get_provider_vector_stores_config(
+ provider=litellm.LlmProviders(custom_llm_provider),
+ api_type=api_type,
+ )
)
if vector_store_provider_config is None:
@@ -401,7 +425,7 @@ def search(
)
local_vars.update(kwargs)
-
+
# Get VectorStoreSearchOptionalRequestParams with only valid parameters
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = (
VectorStoreRequestUtils.get_requested_vector_store_search_optional_param(
@@ -438,7 +462,7 @@ def search(
_is_async=_is_async,
client=kwargs.get("client"),
)
-
+
return response
except Exception as e:
raise litellm.exception_type(
@@ -447,4 +471,4 @@ def search(
original_exception=e,
completion_kwargs=local_vars,
extra_kwargs=kwargs,
- )
\ No newline at end of file
+ )
diff --git a/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py
new file mode 100644
index 0000000000..7cace33861
--- /dev/null
+++ b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py
@@ -0,0 +1,194 @@
+"""
+Test for Vertex AI Search API Vector Store with mocked responses
+"""
+
+import json
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+import litellm
+
+
+# Mock response from actual Vertex AI Search API
+MOCK_VERTEX_SEARCH_RESPONSE = {
+ "results": [
+ {
+ "id": "0",
+ "document": {
+ "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/0",
+ "id": "0",
+ "derivedStructData": {
+ "htmlTitle": "LiteLLM - Getting Started | liteLLM",
+ "snippets": [
+ {
+ "htmlSnippet": "https://github.com/BerriAI/litellm.",
+ "snippet": "https://github.com/BerriAI/litellm.",
+ }
+ ],
+ "title": "LiteLLM - Getting Started | liteLLM",
+ "link": "https://docs.litellm.ai/docs/",
+ "displayLink": "docs.litellm.ai",
+ },
+ },
+ },
+ {
+ "id": "1",
+ "document": {
+ "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/1",
+ "id": "1",
+ "derivedStructData": {
+ "title": "Using Vector Stores (Knowledge Bases) | liteLLM",
+ "link": "https://docs.litellm.ai/docs/completion/knowledgebase",
+ "snippets": [
+ {
+ "snippet": "LiteLLM integrates with vector stores, allowing your models to access your organization's data for more accurate and contextually relevant responses."
+ }
+ ],
+ },
+ },
+ },
+ ],
+ "totalSize": 299,
+ "attributionToken": "mock_token",
+ "summary": {},
+}
+
+
+class TestVertexAISearchAPIVectorStore:
+ """Test Vertex AI Search API Vector Store with mocked responses"""
+
+ @pytest.mark.asyncio
+ async def test_basic_search_with_mock(self):
+ """Test basic vector search with mocked backend response"""
+
+ # Mock the HTTP response
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
+ mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
+
+ # Mock the access token method to avoid real authentication
+ with patch(
+ "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
+ ) as mock_auth:
+ mock_auth.return_value = ("mock_token", "test-vector-store-db")
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+ new_callable=AsyncMock,
+ ) as mock_post:
+ mock_post.return_value = mock_response
+
+ # Make the search request
+ response = await litellm.vector_stores.asearch(
+ query="what is LiteLLM?",
+ vector_store_id="test-litellm-app_1761094730750",
+ custom_llm_provider="vertex_ai/search_api",
+ vertex_project="test-vector-store-db",
+ vertex_location="us-central1",
+ )
+
+ print("Response:", json.dumps(response, indent=2, default=str))
+
+ # Validate the response structure (LiteLLM standard format)
+ assert response is not None
+ assert response["object"] == "vector_store.search_results.page"
+ assert "data" in response
+ assert len(response["data"]) > 0
+ assert "search_query" in response
+
+ # Validate first result
+ first_result = response["data"][0]
+ assert "score" in first_result
+ assert "content" in first_result
+ assert "file_id" in first_result
+ assert "filename" in first_result
+ assert "attributes" in first_result
+
+ # Validate content structure
+ assert len(first_result["content"]) > 0
+ assert first_result["content"][0]["type"] == "text"
+ assert "text" in first_result["content"][0]
+
+ # Verify the API was called
+ mock_post.assert_called_once()
+
+ # Verify the URL format
+ call_args = mock_post.call_args
+ url = call_args[1]["url"] if "url" in call_args[1] else call_args[0][0]
+ assert "discoveryengine.googleapis.com" in url
+ assert "test-vector-store-db" in url
+ assert "test-litellm-app_1761094730750" in url
+
+ def test_basic_search_sync_with_mock(self):
+ """Test basic vector search (sync) with mocked backend response"""
+
+ # Mock the HTTP response
+ mock_response = MagicMock()
+ mock_response.status_code = 200
+ mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
+ mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
+
+ # Mock the access token method to avoid real authentication
+ with patch(
+ "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
+ ) as mock_auth:
+ mock_auth.return_value = ("mock_token", "test-vector-store-db")
+
+ with patch(
+ "litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
+ ) as mock_post:
+ mock_post.return_value = mock_response
+
+ # Make the search request
+ response = litellm.vector_stores.search(
+ query="what is LiteLLM?",
+ vector_store_id="test-litellm-app_1761094730750",
+ custom_llm_provider="vertex_ai/search_api",
+ vertex_project="test-vector-store-db",
+ vertex_location="us-central1",
+ )
+
+ print("Response:", json.dumps(response, indent=2, default=str))
+
+ # Validate the response structure (LiteLLM standard format)
+ assert response is not None
+ assert response["object"] == "vector_store.search_results.page"
+ assert "data" in response
+ assert len(response["data"]) > 0
+ assert "search_query" in response
+
+ # Validate first result structure
+ first_result = response["data"][0]
+ assert "score" in first_result
+ assert "content" in first_result
+ assert "file_id" in first_result
+ assert "filename" in first_result
+ assert "attributes" in first_result
+
+ # Validate content structure
+ assert len(first_result["content"]) > 0
+ assert first_result["content"][0]["type"] == "text"
+ assert "text" in first_result["content"][0]
+
+ # Validate attributes
+ assert "document_id" in first_result["attributes"]
+ assert "link" in first_result["attributes"]
+ assert "title" in first_result["attributes"]
+
+ # Verify the API was called
+ mock_post.assert_called_once()
+
+
+if __name__ == "__main__":
+ # Run tests
+ import asyncio
+
+ test = TestVertexAISearchAPIVectorStore()
+
+ print("Running async test...")
+ asyncio.run(test.test_basic_search_with_mock())
+
+ print("\nRunning sync test...")
+ test.test_basic_search_sync_with_mock()
+
+ print("\n✅ All tests passed!")