From 573306f3cd1e4d1635bbbb840503bfc574687fc7 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Wed, 22 Oct 2025 18:56:36 -0700
Subject: [PATCH] (feat) Vector Stores: support Vertex AI Search API as vector store through LiteLLM (#15781)

* feat(vector_stores/): initial commit adding Vertex AI Search API support for litellm new vector store provider

* feat(vector_store/): use vector store id for vertex ai search api

* fix: transformation.py cleanup

* fix: implement abstract function

* fix: fix linting error

* fix: main.py fix check
---
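[Reviewer note, not part of the commit: a minimal usage sketch of the new
provider path, placed in this patch-notes area that `git am` ignores. The
project and engine IDs below are placeholders; `vertex_ai/search_api` is the
provider string this patch wires into
`ProviderConfigManager.get_provider_vector_stores_config`.

    import litellm

    # Search an existing Vertex AI Search (Discovery Engine) app through the
    # new transformation. Assumes Google credentials are available locally
    # (ADC or a service account), since validate_environment fetches a token.
    response = litellm.vector_stores.search(
        query="what is LiteLLM?",
        vector_store_id="my-engine-id",              # placeholder engine ID
        custom_llm_provider="vertex_ai/search_api",  # selects the Search API config
        vertex_project="my-gcp-project",             # placeholder GCP project
        vertex_location="us-central1",
    )
    print(response["data"][0]["content"][0]["text"])

On the proxy, the same provider is selected via the `vector_store_registry`
entry added to `_new_secret_config.yaml` in this diff.]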
 .gitignore                                    |   1 +
 litellm/llms/custom_httpx/llm_http_handler.py |  22 +-
 .../llms/vertex_ai/vector_stores/__init__.py  |   5 +-
 .../{ => rag_api}/transformation.py           |   0
 .../search_api/transformation.py              | 241 ++++++++++++++++++
 litellm/proxy/_new_secret_config.yaml         |  10 +
 litellm/types/router.py                       |   6 +-
 litellm/utils.py                              |  49 ++--
 litellm/vector_stores/main.py                 |  60 +++--
 .../test_vertex_ai_search_api_vector_store.py | 194 ++++++++++++++
 10 files changed, 529 insertions(+), 59 deletions(-)
 rename litellm/llms/vertex_ai/vector_stores/{ => rag_api}/transformation.py (100%)
 create mode 100644 litellm/llms/vertex_ai/vector_stores/search_api/transformation.py
 create mode 100644 tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py

diff --git a/.gitignore b/.gitignore
index e1045032d4..aa973201fd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,3 +99,4 @@ litellm/proxy/to_delete_loadtest_work/*
 update_model_cost_map.py
 tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
 litellm/proxy/_experimental/out/guardrails/index.html
+scripts/test_vertex_ai_search.py

diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index f99310bbf0..56e5f0c948 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -44,13 +44,8 @@ from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig, OCRResponse
 from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
 from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
 from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
-from litellm.llms.base_llm.search.transformation import (
-    BaseSearchConfig,
-    SearchResponse,
-)
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.search.transformation import BaseSearchConfig, SearchResponse
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
@@ -1313,8 +1308,10 @@ class BaseLLMHTTPHandler:
 
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
-
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
+
         data = transformed_result.data
 
         ## LOGGING
@@ -1377,8 +1374,10 @@ class BaseLLMHTTPHandler:
 
         # Data is always a dict for Mistral OCR format
         if not isinstance(transformed_result.data, dict):
-            raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
-
+            raise ValueError(
+                f"Expected dict data for OCR request, got {type(transformed_result.data)}"
+            )
+
         data = transformed_result.data
 
         ## LOGGING
@@ -1549,7 +1548,6 @@ class BaseLLMHTTPHandler:
             logging_obj=logging_obj,
         )
 
-
     def search(
         self,
         query: Union[str, List[str]],

diff --git a/litellm/llms/vertex_ai/vector_stores/__init__.py b/litellm/llms/vertex_ai/vector_stores/__init__.py
index f3c210a973..98da2c581a 100644
--- a/litellm/llms/vertex_ai/vector_stores/__init__.py
+++ b/litellm/llms/vertex_ai/vector_stores/__init__.py
@@ -1,3 +1,4 @@
-from .transformation import VertexVectorStoreConfig
+from .rag_api.transformation import VertexVectorStoreConfig
+from .search_api.transformation import VertexSearchAPIVectorStoreConfig
 
-__all__ = ["VertexVectorStoreConfig"]
\ No newline at end of file
+__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]

diff --git a/litellm/llms/vertex_ai/vector_stores/transformation.py b/litellm/llms/vertex_ai/vector_stores/rag_api/transformation.py
similarity index 100%
rename from litellm/llms/vertex_ai/vector_stores/transformation.py
rename to litellm/llms/vertex_ai/vector_stores/rag_api/transformation.py

diff --git a/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py b/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py
new file mode 100644
index 0000000000..3d13c99840
--- /dev/null
+++ b/litellm/llms/vertex_ai/vector_stores/search_api/transformation.py
@@ -0,0 +1,241 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import httpx
+
+from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+from litellm.types.router import GenericLiteLLMParams
+from litellm.types.vector_stores import (
+    VectorStoreCreateOptionalRequestParams,
+    VectorStoreCreateResponse,
+    VectorStoreResultContent,
+    VectorStoreSearchOptionalRequestParams,
+    VectorStoreSearchResponse,
+    VectorStoreSearchResult,
+)
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
+    """
+    Configuration for Vertex AI Search API Vector Store
+
+    This implementation uses the Vertex AI Search API for vector store operations.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def validate_environment(
+        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
+    ) -> dict:
+        """
+        Validate and set up authentication for the Vertex AI Search API
+        """
+        litellm_params = litellm_params or GenericLiteLLMParams()
+
+        # Get credentials and project info
+        vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
+        vertex_project = self.get_vertex_ai_project(dict(litellm_params))
+
+        # Get access token using the base class method
+        access_token, project_id = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+
+        headers.update(
+            {
+                "Authorization": f"Bearer {access_token}",
+                "Content-Type": "application/json",
+            }
+        )
+
+        return headers
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        litellm_params: dict,
+    ) -> str:
+        """
+        Get the base endpoint for the Vertex AI Search API
+        """
+        vertex_location = self.get_vertex_ai_location(litellm_params)
+        vertex_project = self.get_vertex_ai_project(litellm_params)
+        engine_id = litellm_params.get("vector_store_id")
+        collection_id = (
+            litellm_params.get("vertex_collection_id") or "default_collection"
+        )
+        if api_base:
+            return api_base.rstrip("/")
+
+        # Vertex AI Search API endpoint for search
+        return (
+            f"https://discoveryengine.googleapis.com/v1/"
+            f"projects/{vertex_project}/locations/{vertex_location}/"
+            f"collections/{collection_id}/engines/{engine_id}/servingConfigs/default_config"
+        )
+
+    def transform_search_vector_store_request(
+        self,
+        vector_store_id: str,
+        query: Union[str, List[str]],
+        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
+        api_base: str,
+        litellm_logging_obj: LiteLLMLoggingObj,
+        litellm_params: dict,
+    ) -> Tuple[str, Dict[str, Any]]:
+        """
+        Transform a search request for the Vertex AI Search API
+        """
+        # Convert query to string if it's a list
+        if isinstance(query, list):
+            query = " ".join(query)
+
+        # Vertex AI Search API endpoint for running a search
+        url = f"{api_base}:search"
+
+        # Build the request body for the Vertex AI Search API
+        request_body = {"query": query, "pageSize": 10}
+
+        #########################################################
+        # Update logging object with details of the request
+        #########################################################
+        litellm_logging_obj.model_call_details["query"] = query
+
+        return url, request_body
+
+    def transform_search_vector_store_response(
+        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
+    ) -> VectorStoreSearchResponse:
+        """
+        Transform a Vertex AI Search API response to the standard vector store search response
+
+        Handles the format from the Discovery Engine Search API, which returns:
+        {
+            "results": [
+                {
+                    "id": "...",
+                    "document": {
+                        "derivedStructData": {
+                            "title": "...",
+                            "link": "...",
+                            "snippets": [...]
+                        }
+                    }
+                }
+            ]
+        }
+        """
+        try:
+            response_json = response.json()
+
+            # Extract results from Vertex AI Search API response
+            results = response_json.get("results", [])
+
+            # Transform results to standard format
+            search_results: List[VectorStoreSearchResult] = []
+            for result in results:
+                document = result.get("document", {})
+                derived_data = document.get("derivedStructData", {})
+
+                # Extract text content from snippets
+                snippets = derived_data.get("snippets", [])
+                text_content = ""
+
+                if snippets:
+                    # Combine all snippets into one text
+                    text_parts = [
+                        snippet.get("snippet", snippet.get("htmlSnippet", ""))
+                        for snippet in snippets
+                    ]
+                    text_content = " ".join(text_parts)
+
+                # If no snippets, use title as fallback
+                if not text_content:
+                    text_content = derived_data.get("title", "")
+
+                content = [
+                    VectorStoreResultContent(
+                        text=text_content,
+                        type="text",
+                    )
+                ]
+
+                # Extract file/document information
+                document_link = derived_data.get("link", "")
+                document_title = derived_data.get("title", "")
+                document_id = result.get("id", "")
+
+                # Use link as file_id if available, otherwise use document ID
+                file_id = document_link if document_link else document_id
+                filename = document_title if document_title else "Unknown Document"
+
+                # Build attributes with available metadata
+                attributes = {
+                    "document_id": document_id,
+                }
+
+                if document_link:
+                    attributes["link"] = document_link
+                if document_title:
+                    attributes["title"] = document_title
+
+                # Add display link if available
+                display_link = derived_data.get("displayLink", "")
+                if display_link:
+                    attributes["displayLink"] = display_link
+
+                # Add formatted URL if available
+                formatted_url = derived_data.get("formattedUrl", "")
+                if formatted_url:
+                    attributes["formattedUrl"] = formatted_url
+
+                # Note: the Search API doesn't provide explicit scores in the response,
+                # so use the result's position/rank as an implicit, decreasing score
+                score = 1.0 / float(len(search_results) + 1)
+
+                result_obj = VectorStoreSearchResult(
+                    score=score,
+                    content=content,
+                    file_id=file_id,
+                    filename=filename,
+                    attributes=attributes,
+                )
+                search_results.append(result_obj)
+
+            return VectorStoreSearchResponse(
+                object="vector_store.search_results.page",
+                search_query=litellm_logging_obj.model_call_details.get("query", ""),
+                data=search_results,
+            )
+
+        except Exception as e:
+            raise self.get_error_class(
+                error_message=str(e),
+                status_code=response.status_code,
+                headers=response.headers,
+            )
+
+    def transform_create_vector_store_request(
+        self,
+        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
+        api_base: str,
+    ) -> Tuple[str, Dict]:
+        raise NotImplementedError
+
+    def transform_create_vector_store_response(
+        self, response: httpx.Response
+    ) -> VectorStoreCreateResponse:
+        raise NotImplementedError

diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 2a2e6c9940..7725446d10 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -3,6 +3,16 @@ model_list:
     litellm_params:
       model: bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0
 
+vector_store_registry:
+  - vector_store_name: "vertex-ai-litellm-website-knowledgebase"
+    litellm_params:
+      vector_store_id: "test-litellm-app_1761094730750"
+      custom_llm_provider: "vertex_ai/search_api"
+      vertex_project: "test-litellm-app"
+      vertex_location: "us-central1"
+    vector_store_description: "Vertex AI vector store for the Litellm website knowledgebase"
+    vector_store_metadata:
+      source: "https://www.litellm.com/docs"
 mcp_servers:
   github_mcp:
     url: "https://api.githubcopilot.com/mcp"

diff --git a/litellm/types/router.py b/litellm/types/router.py
index 8790d0d558..354e97d14d 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -163,9 +163,6 @@ class CredentialLiteLLMParams(BaseModel):
 
     watsonx_region_name: Optional[str] = None
 
-
-
-
 class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     """
     LiteLLM Params without 'model' arg (used across completion / assistants api)
@@ -210,6 +207,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
     s3_bucket_name: Optional[str] = None
     gcs_bucket_name: Optional[str] = None
 
+    # Vector Store Params
+    vector_store_id: Optional[str] = None
+
     def __init__(
         self,
         custom_llm_provider: Optional[str] = None,

diff --git a/litellm/utils.py b/litellm/utils.py
index d21c09c69f..f1e56954fe 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -145,9 +145,7 @@ from litellm.llms.base_llm.google_genai.transformation import (
 )
 from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
 from litellm.llms.base_llm.search.transformation import BaseSearchConfig
-from litellm.llms.base_llm.text_to_speech.transformation import (
-    BaseTextToSpeechConfig,
-)
+from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
 from litellm.llms.bedrock.common_utils import BedrockModelInfo
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
@@ -914,6 +912,7 @@ def _get_wrapper_timeout(
 
     return timeout
 
+
 def check_coroutine(value) -> bool:
     return get_coroutine_checker().is_async_callable(value)
 
@@ -993,9 +992,7 @@ def post_call_processing(
             ].message.content  # type: ignore
             if model_response is not None:
                 ### POST-CALL RULES ###
-                rules_obj.post_call_rules(
-                    input=model_response, model=model
-                )
+                rules_obj.post_call_rules(input=model_response, model=model)
                 ### JSON SCHEMA VALIDATION ###
                 if litellm.enable_json_schema_validation is True:
                     try:
@@ -1011,9 +1008,9 @@ def post_call_processing(
                                 optional_params["response_format"],
                                 dict,
                             )
-                            and optional_params[
-                                "response_format"
-                            ].get("json_schema")
+                            and optional_params["response_format"].get(
+                                "json_schema"
+                            )
                             is not None
                         ):
                             json_response_format = optional_params[
@@ -1041,9 +1038,7 @@ def post_call_processing(
                     if (
                         optional_params is not None
                         and "response_format" in optional_params
-                        and isinstance(
-                            optional_params["response_format"], dict
-                        )
+                        and isinstance(optional_params["response_format"], dict)
                         and "type" in optional_params["response_format"]
                         and optional_params["response_format"]["type"]
                         == "json_object"
@@ -1077,7 +1072,6 @@
 def client(original_function):  # noqa: PLR0915
     rules_obj = Rules()
 
-
     @wraps(original_function)
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
@@ -5000,7 +4994,9 @@ def _get_model_info_helper(  # noqa: PLR0915
                 tpm=_model_info.get("tpm", None),
                 rpm=_model_info.get("rpm", None),
                 ocr_cost_per_page=_model_info.get("ocr_cost_per_page", None),
-                annotation_cost_per_page=_model_info.get("annotation_cost_per_page", None),
+                annotation_cost_per_page=_model_info.get(
+                    "annotation_cost_per_page", None
+                ),
             )
     except Exception as e:
         verbose_logger.debug(f"Error getting model info: {e}")
@@ -7226,6 +7222,7 @@ class ProviderConfigManager:
             from litellm.llms.sagemaker.embedding.transformation import (
                 SagemakerEmbeddingConfig,
             )
+
             return SagemakerEmbeddingConfig.get_model_config(model)
 
         return None
@@ -7461,6 +7458,7 @@ class ProviderConfigManager:
     @staticmethod
    def get_provider_vector_stores_config(
         provider: LlmProviders,
+        api_type: Optional[str] = None,
     ) -> Optional[BaseVectorStoreConfig]:
         """
         v2 vector store config, use this for new vector store integrations
@@ -7478,11 +7476,18 @@ class ProviderConfigManager:
 
             return AzureOpenAIVectorStoreConfig()
         elif litellm.LlmProviders.VERTEX_AI == provider:
-            from litellm.llms.vertex_ai.vector_stores.transformation import (
-                VertexVectorStoreConfig,
-            )
+            if api_type == "rag_api" or api_type is None:  # default to rag_api
+                from litellm.llms.vertex_ai.vector_stores.rag_api.transformation import (
+                    VertexVectorStoreConfig,
+                )
 
-            return VertexVectorStoreConfig()
+                return VertexVectorStoreConfig()
+            elif api_type == "search_api":
+                from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
+                    VertexSearchAPIVectorStoreConfig,
+                )
+
+                return VertexSearchAPIVectorStoreConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             from litellm.llms.bedrock.vector_stores.transformation import (
                 BedrockVectorStoreConfig,
             )
@@ -7640,12 +7645,8 @@ class ProviderConfigManager:
         from litellm.llms.parallel_ai.search.transformation import (
             ParallelAISearchConfig,
         )
-        from litellm.llms.perplexity.search.transformation import (
-            PerplexitySearchConfig,
-        )
-        from litellm.llms.tavily.search.transformation import (
-            TavilySearchConfig,
-        )
+        from litellm.llms.perplexity.search.transformation import PerplexitySearchConfig
+        from litellm.llms.tavily.search.transformation import TavilySearchConfig
 
         PROVIDER_TO_CONFIG_MAP = {
             SearchProviders.PERPLEXITY: PerplexitySearchConfig,

diff --git a/litellm/vector_stores/main.py b/litellm/vector_stores/main.py
index 01e06e1306..3cbbfea180 100644
--- a/litellm/vector_stores/main.py
+++ b/litellm/vector_stores/main.py
@@ -1,6 +1,7 @@
 """
 LiteLLM SDK Functions for Creating and Searching Vector Stores
 """
+
 import asyncio
 import contextvars
 from functools import partial
@@ -10,6 +11,7 @@ import httpx
 
 import litellm
 from litellm.constants import request_timeout
+from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
 from litellm.types.router import GenericLiteLLMParams
@@ -42,12 +44,12 @@ def mock_vector_store_search_response(
             content=[
                 VectorStoreResultContent(
                     text="This is a sample search result from the vector store.",
-                    type="text"
+                    type="text",
                 )
-            ]
+            ],
         )
     ]
-
+
     return VectorStoreSearchResponse(
         object="vector_store.search_results.page",
         search_query="sample query",
@@ -79,7 +81,7 @@
         last_active_at=None,
         metadata=None,
     )
-
+
     return mock_response
@@ -166,14 +168,14 @@
 ) -> Union[VectorStoreCreateResponse, Coroutine[Any, Any, VectorStoreCreateResponse]]:
     """
     Create a vector store.
-
+
     Args:
         name: The name of the vector store.
         file_ids: A list of File IDs that the vector store should use.
         expires_after: The expiration policy for the vector store.
         chunking_strategy: The chunking strategy used to chunk the file(s).
         metadata: Set of 16 key-value pairs that can be attached to an object.
-
+
     Returns:
         VectorStoreCreateResponse containing the created vector store details.
     """
@@ -198,9 +200,18 @@
         if custom_llm_provider is None:
             custom_llm_provider = "openai"
 
+        api_type, custom_llm_provider, _, _ = get_llm_provider(
+            model=custom_llm_provider,
+            custom_llm_provider=None,
+            litellm_params=None,
+        )
+
         # get provider config - using vector store custom logger for now
-        vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
-            provider=litellm.LlmProviders(custom_llm_provider),
+        vector_store_provider_config = (
+            ProviderConfigManager.get_provider_vector_stores_config(
+                provider=litellm.LlmProviders(custom_llm_provider),
+                api_type=api_type,
+            )
         )
 
         if vector_store_provider_config is None:
             raise ValueError(
             )
 
         local_vars.update(kwargs)
-
+
         # Get VectorStoreCreateOptionalRequestParams with only valid parameters
         vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams = (
             VectorStoreRequestUtils.get_requested_vector_store_create_optional_param(
@@ -242,7 +253,7 @@
             _is_async=_is_async,
             client=kwargs.get("client"),
         )
-
+
        return response
     except Exception as e:
         raise litellm.exception_type(
@@ -340,7 +351,7 @@
 ) -> Union[VectorStoreSearchResponse, Coroutine[Any, Any, VectorStoreSearchResponse]]:
     """
     Search a vector store for relevant chunks based on a query and file attributes filter.
-
+
     Args:
         vector_store_id: The ID of the vector store to search.
         query: A query string or array for the search.
         filters: A filter to apply based on file attributes.
         max_num_results: Maximum number of results to return (1-50, default 10).
         ranking_options: Optional ranking options for search.
         rewrite_query: Whether to rewrite the natural language query for vector search.
-
+
     Returns:
         VectorStoreSearchResponse containing the search results.
""" @@ -375,7 +386,7 @@ def search( pass # get llm provider logic - litellm_params = GenericLiteLLMParams(**kwargs) + litellm_params = GenericLiteLLMParams(vector_store_id=vector_store_id, **kwargs) ## MOCK RESPONSE LOGIC if litellm_params.mock_response and isinstance( @@ -390,9 +401,22 @@ def search( if custom_llm_provider is None: custom_llm_provider = "openai" + if "/" in custom_llm_provider: + api_type, custom_llm_provider, _, _ = get_llm_provider( + model=custom_llm_provider, + custom_llm_provider=None, + litellm_params=None, + ) + else: + api_type = None + custom_llm_provider = custom_llm_provider + # get provider config - using vector store custom logger for now - vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config( - provider=litellm.LlmProviders(custom_llm_provider), + vector_store_provider_config = ( + ProviderConfigManager.get_provider_vector_stores_config( + provider=litellm.LlmProviders(custom_llm_provider), + api_type=api_type, + ) ) if vector_store_provider_config is None: @@ -401,7 +425,7 @@ def search( ) local_vars.update(kwargs) - + # Get VectorStoreSearchOptionalRequestParams with only valid parameters vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams = ( VectorStoreRequestUtils.get_requested_vector_store_search_optional_param( @@ -438,7 +462,7 @@ def search( _is_async=_is_async, client=kwargs.get("client"), ) - + return response except Exception as e: raise litellm.exception_type( @@ -447,4 +471,4 @@ def search( original_exception=e, completion_kwargs=local_vars, extra_kwargs=kwargs, - ) \ No newline at end of file + ) diff --git a/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py new file mode 100644 index 0000000000..7cace33861 --- /dev/null +++ b/tests/vector_store_tests/test_vertex_ai_search_api_vector_store.py @@ -0,0 +1,194 @@ +""" +Test for Vertex AI Search API Vector Store with mocked responses +""" + +import json +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +import litellm + + +# Mock response from actual Vertex AI Search API +MOCK_VERTEX_SEARCH_RESPONSE = { + "results": [ + { + "id": "0", + "document": { + "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/0", + "id": "0", + "derivedStructData": { + "htmlTitle": "LiteLLM - Getting Started | liteLLM", + "snippets": [ + { + "htmlSnippet": "https://github.com/BerriAI/litellm.", + "snippet": "https://github.com/BerriAI/litellm.", + } + ], + "title": "LiteLLM - Getting Started | liteLLM", + "link": "https://docs.litellm.ai/docs/", + "displayLink": "docs.litellm.ai", + }, + }, + }, + { + "id": "1", + "document": { + "name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/1", + "id": "1", + "derivedStructData": { + "title": "Using Vector Stores (Knowledge Bases) | liteLLM", + "link": "https://docs.litellm.ai/docs/completion/knowledgebase", + "snippets": [ + { + "snippet": "LiteLLM integrates with vector stores, allowing your models to access your organization's data for more accurate and contextually relevant responses." 
+                        }
+                    ],
+                },
+            },
+        },
+    ],
+    "totalSize": 299,
+    "attributionToken": "mock_token",
+    "summary": {},
+}
+
+
+class TestVertexAISearchAPIVectorStore:
+    """Test Vertex AI Search API Vector Store with mocked responses"""
+
+    @pytest.mark.asyncio
+    async def test_basic_search_with_mock(self):
+        """Test basic vector search with mocked backend response"""
+
+        # Mock the HTTP response
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
+        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
+
+        # Mock the access token method to avoid real authentication
+        with patch(
+            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
+        ) as mock_auth:
+            mock_auth.return_value = ("mock_token", "test-vector-store-db")
+
+            with patch(
+                "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+                new_callable=AsyncMock,
+            ) as mock_post:
+                mock_post.return_value = mock_response
+
+                # Make the search request
+                response = await litellm.vector_stores.asearch(
+                    query="what is LiteLLM?",
+                    vector_store_id="test-litellm-app_1761094730750",
+                    custom_llm_provider="vertex_ai/search_api",
+                    vertex_project="test-vector-store-db",
+                    vertex_location="us-central1",
+                )
+
+                print("Response:", json.dumps(response, indent=2, default=str))
+
+                # Validate the response structure (LiteLLM standard format)
+                assert response is not None
+                assert response["object"] == "vector_store.search_results.page"
+                assert "data" in response
+                assert len(response["data"]) > 0
+                assert "search_query" in response
+
+                # Validate first result
+                first_result = response["data"][0]
+                assert "score" in first_result
+                assert "content" in first_result
+                assert "file_id" in first_result
+                assert "filename" in first_result
+                assert "attributes" in first_result
+
+                # Validate content structure
+                assert len(first_result["content"]) > 0
+                assert first_result["content"][0]["type"] == "text"
+                assert "text" in first_result["content"][0]
+
+                # Verify the API was called
+                mock_post.assert_called_once()
+
+                # Verify the URL format
+                call_args = mock_post.call_args
+                url = call_args[1]["url"] if "url" in call_args[1] else call_args[0][0]
+                assert "discoveryengine.googleapis.com" in url
+                assert "test-vector-store-db" in url
+                assert "test-litellm-app_1761094730750" in url
+
+    def test_basic_search_sync_with_mock(self):
+        """Test basic vector search (sync) with mocked backend response"""
+
+        # Mock the HTTP response
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
+        mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
+
+        # Mock the access token method to avoid real authentication
+        with patch(
+            "litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
+        ) as mock_auth:
+            mock_auth.return_value = ("mock_token", "test-vector-store-db")
+
+            with patch(
+                "litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
+            ) as mock_post:
+                mock_post.return_value = mock_response
+
+                # Make the search request
+                response = litellm.vector_stores.search(
+                    query="what is LiteLLM?",
+                    vector_store_id="test-litellm-app_1761094730750",
+                    custom_llm_provider="vertex_ai/search_api",
+                    vertex_project="test-vector-store-db",
+                    vertex_location="us-central1",
+                )
+
+                print("Response:", json.dumps(response, indent=2, default=str))
+
+                # Validate the response structure (LiteLLM standard format)
+                assert response is not None
+                assert response["object"] == "vector_store.search_results.page"
+                assert "data" in response
+                assert len(response["data"]) > 0
+                assert "search_query" in response
+
+                # Validate first result structure
+                first_result = response["data"][0]
+                assert "score" in first_result
+                assert "content" in first_result
+                assert "file_id" in first_result
+                assert "filename" in first_result
+                assert "attributes" in first_result
+
+                # Validate content structure
+                assert len(first_result["content"]) > 0
+                assert first_result["content"][0]["type"] == "text"
+                assert "text" in first_result["content"][0]
+
+                # Validate attributes
+                assert "document_id" in first_result["attributes"]
+                assert "link" in first_result["attributes"]
+                assert "title" in first_result["attributes"]
+
+                # Verify the API was called
+                mock_post.assert_called_once()
+
+
+if __name__ == "__main__":
+    # Run tests
+    import asyncio
+
+    test = TestVertexAISearchAPIVectorStore()
+
+    print("Running async test...")
+    asyncio.run(test.test_basic_search_with_mock())
+
+    print("\nRunning sync test...")
+    test.test_basic_search_sync_with_mock()
+
+    print("\n✅ All tests passed!")
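
[Reviewer note, not part of the commit: for reference, the request that the
new Search API transformation emits, assembled from `get_complete_url` and
`transform_search_vector_store_request` above. The project and engine IDs are
placeholders; "default_collection" and "default_config" are the hard-coded
fallbacks in the transformation, and pageSize is currently fixed at 10.

    # POST to {url} with the JSON body below and an
    # "Authorization: Bearer <access token>" header.
    url = (
        "https://discoveryengine.googleapis.com/v1/"
        "projects/my-gcp-project/locations/us-central1/"
        "collections/default_collection/engines/my-engine-id/"
        "servingConfigs/default_config:search"
    )
    body = {"query": "what is LiteLLM?", "pageSize": 10}

Scores in the transformed response are synthetic (1 / (rank + 1)), since the
Search API does not return explicit relevance scores.]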