mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
(feat) Vector Stores: support Vertex AI Search API as vector store through LiteLLM (#15781)
* feat(vector_stores/): initial commit adding Vertex AI Search API support for litellm new vector store provider * feat(vector_store/): use vector store id for vertex ai search api * fix: transformation.py cleanup * fix: implement abstract function * fix: fix linting error * fix: main.py fix check
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -99,3 +99,4 @@ litellm/proxy/to_delete_loadtest_work/*
|
||||
update_model_cost_map.py
|
||||
tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py
|
||||
litellm/proxy/_experimental/out/guardrails/index.html
|
||||
scripts/test_vertex_ai_search.py
|
||||
|
||||
@@ -44,13 +44,8 @@ from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig, OCRResponse
|
||||
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
|
||||
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.llms.base_llm.search.transformation import (
|
||||
BaseSearchConfig,
|
||||
SearchResponse,
|
||||
)
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import (
|
||||
BaseTextToSpeechConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.search.transformation import BaseSearchConfig, SearchResponse
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
@@ -1313,7 +1308,9 @@ class BaseLLMHTTPHandler:
|
||||
|
||||
# Data is always a dict for Mistral OCR format
|
||||
if not isinstance(transformed_result.data, dict):
|
||||
raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
|
||||
raise ValueError(
|
||||
f"Expected dict data for OCR request, got {type(transformed_result.data)}"
|
||||
)
|
||||
|
||||
data = transformed_result.data
|
||||
|
||||
@@ -1377,7 +1374,9 @@ class BaseLLMHTTPHandler:
|
||||
|
||||
# Data is always a dict for Mistral OCR format
|
||||
if not isinstance(transformed_result.data, dict):
|
||||
raise ValueError(f"Expected dict data for OCR request, got {type(transformed_result.data)}")
|
||||
raise ValueError(
|
||||
f"Expected dict data for OCR request, got {type(transformed_result.data)}"
|
||||
)
|
||||
|
||||
data = transformed_result.data
|
||||
|
||||
@@ -1549,7 +1548,6 @@ class BaseLLMHTTPHandler:
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Union[str, List[str]],
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .transformation import VertexVectorStoreConfig
|
||||
from .rag_api.transformation import VertexVectorStoreConfig
|
||||
from .search_api.transformation import VertexSearchAPIVectorStoreConfig
|
||||
|
||||
__all__ = ["VertexVectorStoreConfig"]
|
||||
__all__ = ["VertexVectorStoreConfig", "VertexSearchAPIVectorStoreConfig"]
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class VertexSearchAPIVectorStoreConfig(BaseVectorStoreConfig, VertexBase):
|
||||
"""
|
||||
Configuration for Vertex AI Search API Vector Store
|
||||
|
||||
This implementation uses the Vertex AI Search API for vector store operations.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
"""
|
||||
Validate and set up authentication for Vertex AI RAG API
|
||||
"""
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
|
||||
# Get credentials and project info
|
||||
vertex_credentials = self.get_vertex_ai_credentials(dict(litellm_params))
|
||||
vertex_project = self.get_vertex_ai_project(dict(litellm_params))
|
||||
|
||||
# Get access token using the base class method
|
||||
access_token, project_id = self._ensure_access_token(
|
||||
credentials=vertex_credentials,
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider="vertex_ai",
|
||||
)
|
||||
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the Base endpoint for Vertex AI Search API
|
||||
"""
|
||||
vertex_location = self.get_vertex_ai_location(litellm_params)
|
||||
vertex_project = self.get_vertex_ai_project(litellm_params)
|
||||
engine_id = litellm_params.get("vector_store_id")
|
||||
collection_id = (
|
||||
litellm_params.get("vertex_collection_id") or "default_collection"
|
||||
)
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
|
||||
# Vertex AI Search API endpoint for search
|
||||
return (
|
||||
f"https://discoveryengine.googleapis.com/v1/"
|
||||
f"projects/{vertex_project}/locations/{vertex_location}/"
|
||||
f"collections/{collection_id}/engines/{engine_id}/servingConfigs/default_config"
|
||||
)
|
||||
|
||||
def transform_search_vector_store_request(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: Union[str, List[str]],
|
||||
vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
|
||||
api_base: str,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Transform search request for Vertex AI RAG API
|
||||
"""
|
||||
# Convert query to string if it's a list
|
||||
if isinstance(query, list):
|
||||
query = " ".join(query)
|
||||
|
||||
# Vertex AI RAG API endpoint for retrieving contexts
|
||||
url = f"{api_base}:search"
|
||||
|
||||
# Construct full rag corpus path
|
||||
# Build the request body for Vertex AI Search API
|
||||
request_body = {"query": query, "pageSize": 10}
|
||||
|
||||
#########################################################
|
||||
# Update logging object with details of the request
|
||||
#########################################################
|
||||
litellm_logging_obj.model_call_details["query"] = query
|
||||
|
||||
return url, request_body
|
||||
|
||||
def transform_search_vector_store_response(
|
||||
self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
|
||||
) -> VectorStoreSearchResponse:
|
||||
"""
|
||||
Transform Vertex AI Search API response to standard vector store search response
|
||||
|
||||
Handles the format from Discovery Engine Search API which returns:
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"id": "...",
|
||||
"document": {
|
||||
"derivedStructData": {
|
||||
"title": "...",
|
||||
"link": "...",
|
||||
"snippets": [...]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_json = response.json()
|
||||
|
||||
# Extract results from Vertex AI Search API response
|
||||
results = response_json.get("results", [])
|
||||
|
||||
# Transform results to standard format
|
||||
search_results: List[VectorStoreSearchResult] = []
|
||||
for result in results:
|
||||
document = result.get("document", {})
|
||||
derived_data = document.get("derivedStructData", {})
|
||||
|
||||
# Extract text content from snippets
|
||||
snippets = derived_data.get("snippets", [])
|
||||
text_content = ""
|
||||
|
||||
if snippets:
|
||||
# Combine all snippets into one text
|
||||
text_parts = [
|
||||
snippet.get("snippet", snippet.get("htmlSnippet", ""))
|
||||
for snippet in snippets
|
||||
]
|
||||
text_content = " ".join(text_parts)
|
||||
|
||||
# If no snippets, use title as fallback
|
||||
if not text_content:
|
||||
text_content = derived_data.get("title", "")
|
||||
|
||||
content = [
|
||||
VectorStoreResultContent(
|
||||
text=text_content,
|
||||
type="text",
|
||||
)
|
||||
]
|
||||
|
||||
# Extract file/document information
|
||||
document_link = derived_data.get("link", "")
|
||||
document_title = derived_data.get("title", "")
|
||||
document_id = result.get("id", "")
|
||||
|
||||
# Use link as file_id if available, otherwise use document ID
|
||||
file_id = document_link if document_link else document_id
|
||||
filename = document_title if document_title else "Unknown Document"
|
||||
|
||||
# Build attributes with available metadata
|
||||
attributes = {
|
||||
"document_id": document_id,
|
||||
}
|
||||
|
||||
if document_link:
|
||||
attributes["link"] = document_link
|
||||
if document_title:
|
||||
attributes["title"] = document_title
|
||||
|
||||
# Add display link if available
|
||||
display_link = derived_data.get("displayLink", "")
|
||||
if display_link:
|
||||
attributes["displayLink"] = display_link
|
||||
|
||||
# Add formatted URL if available
|
||||
formatted_url = derived_data.get("formattedUrl", "")
|
||||
if formatted_url:
|
||||
attributes["formattedUrl"] = formatted_url
|
||||
|
||||
# Note: Search API doesn't provide explicit scores in the response
|
||||
# You can use the position/rank as an implicit score
|
||||
score = 1.0 / (
|
||||
float(search_results.__len__() + 1)
|
||||
) # Decreasing score based on position
|
||||
|
||||
result_obj = VectorStoreSearchResult(
|
||||
score=score,
|
||||
content=content,
|
||||
file_id=file_id,
|
||||
filename=filename,
|
||||
attributes=attributes,
|
||||
)
|
||||
search_results.append(result_obj)
|
||||
|
||||
return VectorStoreSearchResponse(
|
||||
object="vector_store.search_results.page",
|
||||
search_query=litellm_logging_obj.model_call_details.get("query", ""),
|
||||
data=search_results,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=str(e),
|
||||
status_code=response.status_code,
|
||||
headers=response.headers,
|
||||
)
|
||||
|
||||
def transform_create_vector_store_request(
|
||||
self,
|
||||
vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
|
||||
api_base: str,
|
||||
) -> Tuple[str, Dict]:
|
||||
raise NotImplementedError
|
||||
|
||||
def transform_create_vector_store_response(
|
||||
self, response: httpx.Response
|
||||
) -> VectorStoreCreateResponse:
|
||||
raise NotImplementedError
|
||||
@@ -3,6 +3,16 @@ model_list:
|
||||
litellm_params:
|
||||
model: bedrock/global.anthropic.claude-sonnet-4-5-20250929-v1:0
|
||||
|
||||
vector_store_registry:
|
||||
- vector_store_name: "vertex-ai-litellm-website-knowledgebase"
|
||||
litellm_params:
|
||||
vector_store_id: "test-litellm-app_1761094730750"
|
||||
custom_llm_provider: "vertex_ai/search_api"
|
||||
vertex_project: "test-litellm-app"
|
||||
vertex_location: "us-central1"
|
||||
vector_store_description: "Vertex AI vector store for the Litellm website knowledgebase"
|
||||
vector_store_metadata:
|
||||
source: "https://www.litellm.com/docs"
|
||||
mcp_servers:
|
||||
github_mcp:
|
||||
url: "https://api.githubcopilot.com/mcp"
|
||||
|
||||
@@ -163,9 +163,6 @@ class CredentialLiteLLMParams(BaseModel):
|
||||
watsonx_region_name: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
|
||||
"""
|
||||
LiteLLM Params without 'model' arg (used across completion / assistants api)
|
||||
@@ -210,6 +207,9 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
|
||||
s3_bucket_name: Optional[str] = None
|
||||
gcs_bucket_name: Optional[str] = None
|
||||
|
||||
# Vector Store Params
|
||||
vector_store_id: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
|
||||
@@ -145,9 +145,7 @@ from litellm.llms.base_llm.google_genai.transformation import (
|
||||
)
|
||||
from litellm.llms.base_llm.ocr.transformation import BaseOCRConfig
|
||||
from litellm.llms.base_llm.search.transformation import BaseSearchConfig
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import (
|
||||
BaseTextToSpeechConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import BaseTextToSpeechConfig
|
||||
from litellm.llms.bedrock.common_utils import BedrockModelInfo
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.mistral.ocr.transformation import MistralOCRConfig
|
||||
@@ -914,6 +912,7 @@ def _get_wrapper_timeout(
|
||||
|
||||
return timeout
|
||||
|
||||
|
||||
def check_coroutine(value) -> bool:
|
||||
return get_coroutine_checker().is_async_callable(value)
|
||||
|
||||
@@ -993,9 +992,7 @@ def post_call_processing(
|
||||
].message.content # type: ignore
|
||||
if model_response is not None:
|
||||
### POST-CALL RULES ###
|
||||
rules_obj.post_call_rules(
|
||||
input=model_response, model=model
|
||||
)
|
||||
rules_obj.post_call_rules(input=model_response, model=model)
|
||||
### JSON SCHEMA VALIDATION ###
|
||||
if litellm.enable_json_schema_validation is True:
|
||||
try:
|
||||
@@ -1011,9 +1008,9 @@ def post_call_processing(
|
||||
optional_params["response_format"],
|
||||
dict,
|
||||
)
|
||||
and optional_params[
|
||||
"response_format"
|
||||
].get("json_schema")
|
||||
and optional_params["response_format"].get(
|
||||
"json_schema"
|
||||
)
|
||||
is not None
|
||||
):
|
||||
json_response_format = optional_params[
|
||||
@@ -1041,9 +1038,7 @@ def post_call_processing(
|
||||
if (
|
||||
optional_params is not None
|
||||
and "response_format" in optional_params
|
||||
and isinstance(
|
||||
optional_params["response_format"], dict
|
||||
)
|
||||
and isinstance(optional_params["response_format"], dict)
|
||||
and "type" in optional_params["response_format"]
|
||||
and optional_params["response_format"]["type"]
|
||||
== "json_object"
|
||||
@@ -1077,7 +1072,6 @@ def post_call_processing(
|
||||
def client(original_function): # noqa: PLR0915
|
||||
rules_obj = Rules()
|
||||
|
||||
|
||||
@wraps(original_function)
|
||||
def wrapper(*args, **kwargs): # noqa: PLR0915
|
||||
# DO NOT MOVE THIS. It always needs to run first
|
||||
@@ -5000,7 +4994,9 @@ def _get_model_info_helper( # noqa: PLR0915
|
||||
tpm=_model_info.get("tpm", None),
|
||||
rpm=_model_info.get("rpm", None),
|
||||
ocr_cost_per_page=_model_info.get("ocr_cost_per_page", None),
|
||||
annotation_cost_per_page=_model_info.get("annotation_cost_per_page", None),
|
||||
annotation_cost_per_page=_model_info.get(
|
||||
"annotation_cost_per_page", None
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error getting model info: {e}")
|
||||
@@ -7226,6 +7222,7 @@ class ProviderConfigManager:
|
||||
from litellm.llms.sagemaker.embedding.transformation import (
|
||||
SagemakerEmbeddingConfig,
|
||||
)
|
||||
|
||||
return SagemakerEmbeddingConfig.get_model_config(model)
|
||||
return None
|
||||
|
||||
@@ -7461,6 +7458,7 @@ class ProviderConfigManager:
|
||||
@staticmethod
|
||||
def get_provider_vector_stores_config(
|
||||
provider: LlmProviders,
|
||||
api_type: Optional[str] = None,
|
||||
) -> Optional[BaseVectorStoreConfig]:
|
||||
"""
|
||||
v2 vector store config, use this for new vector store integrations
|
||||
@@ -7478,11 +7476,18 @@ class ProviderConfigManager:
|
||||
|
||||
return AzureOpenAIVectorStoreConfig()
|
||||
elif litellm.LlmProviders.VERTEX_AI == provider:
|
||||
from litellm.llms.vertex_ai.vector_stores.transformation import (
|
||||
if api_type == "rag_api" or api_type is None: # default to rag_api
|
||||
from litellm.llms.vertex_ai.vector_stores.rag_api.transformation import (
|
||||
VertexVectorStoreConfig,
|
||||
)
|
||||
|
||||
return VertexVectorStoreConfig()
|
||||
elif api_type == "search_api":
|
||||
from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
|
||||
VertexSearchAPIVectorStoreConfig,
|
||||
)
|
||||
|
||||
return VertexSearchAPIVectorStoreConfig()
|
||||
elif litellm.LlmProviders.BEDROCK == provider:
|
||||
from litellm.llms.bedrock.vector_stores.transformation import (
|
||||
BedrockVectorStoreConfig,
|
||||
@@ -7640,12 +7645,8 @@ class ProviderConfigManager:
|
||||
from litellm.llms.parallel_ai.search.transformation import (
|
||||
ParallelAISearchConfig,
|
||||
)
|
||||
from litellm.llms.perplexity.search.transformation import (
|
||||
PerplexitySearchConfig,
|
||||
)
|
||||
from litellm.llms.tavily.search.transformation import (
|
||||
TavilySearchConfig,
|
||||
)
|
||||
from litellm.llms.perplexity.search.transformation import PerplexitySearchConfig
|
||||
from litellm.llms.tavily.search.transformation import TavilySearchConfig
|
||||
|
||||
PROVIDER_TO_CONFIG_MAP = {
|
||||
SearchProviders.PERPLEXITY: PerplexitySearchConfig,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
LiteLLM SDK Functions for Creating and Searching Vector Stores
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
@@ -10,6 +11,7 @@ import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
@@ -42,9 +44,9 @@ def mock_vector_store_search_response(
|
||||
content=[
|
||||
VectorStoreResultContent(
|
||||
text="This is a sample search result from the vector store.",
|
||||
type="text"
|
||||
type="text",
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
]
|
||||
|
||||
@@ -198,9 +200,18 @@ def create(
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
api_type, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=custom_llm_provider,
|
||||
custom_llm_provider=None,
|
||||
litellm_params=None,
|
||||
)
|
||||
|
||||
# get provider config - using vector store custom logger for now
|
||||
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
|
||||
vector_store_provider_config = (
|
||||
ProviderConfigManager.get_provider_vector_stores_config(
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
api_type=api_type,
|
||||
)
|
||||
)
|
||||
|
||||
if vector_store_provider_config is None:
|
||||
@@ -375,7 +386,7 @@ def search(
|
||||
pass
|
||||
|
||||
# get llm provider logic
|
||||
litellm_params = GenericLiteLLMParams(**kwargs)
|
||||
litellm_params = GenericLiteLLMParams(vector_store_id=vector_store_id, **kwargs)
|
||||
|
||||
## MOCK RESPONSE LOGIC
|
||||
if litellm_params.mock_response and isinstance(
|
||||
@@ -390,9 +401,22 @@ def search(
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = "openai"
|
||||
|
||||
if "/" in custom_llm_provider:
|
||||
api_type, custom_llm_provider, _, _ = get_llm_provider(
|
||||
model=custom_llm_provider,
|
||||
custom_llm_provider=None,
|
||||
litellm_params=None,
|
||||
)
|
||||
else:
|
||||
api_type = None
|
||||
custom_llm_provider = custom_llm_provider
|
||||
|
||||
# get provider config - using vector store custom logger for now
|
||||
vector_store_provider_config = ProviderConfigManager.get_provider_vector_stores_config(
|
||||
vector_store_provider_config = (
|
||||
ProviderConfigManager.get_provider_vector_stores_config(
|
||||
provider=litellm.LlmProviders(custom_llm_provider),
|
||||
api_type=api_type,
|
||||
)
|
||||
)
|
||||
|
||||
if vector_store_provider_config is None:
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
Test for Vertex AI Search API Vector Store with mocked responses
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import litellm
|
||||
|
||||
|
||||
# Mock response from actual Vertex AI Search API
|
||||
MOCK_VERTEX_SEARCH_RESPONSE = {
|
||||
"results": [
|
||||
{
|
||||
"id": "0",
|
||||
"document": {
|
||||
"name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/0",
|
||||
"id": "0",
|
||||
"derivedStructData": {
|
||||
"htmlTitle": "<b>LiteLLM</b> - Getting Started | <b>liteLLM</b>",
|
||||
"snippets": [
|
||||
{
|
||||
"htmlSnippet": "https://github.com/BerriAI/<b>litellm</b>.",
|
||||
"snippet": "https://github.com/BerriAI/litellm.",
|
||||
}
|
||||
],
|
||||
"title": "LiteLLM - Getting Started | liteLLM",
|
||||
"link": "https://docs.litellm.ai/docs/",
|
||||
"displayLink": "docs.litellm.ai",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "1",
|
||||
"document": {
|
||||
"name": "projects/648660250433/locations/global/collections/default_collection/dataStores/litellm-docs_1761094140318/branches/0/documents/1",
|
||||
"id": "1",
|
||||
"derivedStructData": {
|
||||
"title": "Using Vector Stores (Knowledge Bases) | liteLLM",
|
||||
"link": "https://docs.litellm.ai/docs/completion/knowledgebase",
|
||||
"snippets": [
|
||||
{
|
||||
"snippet": "LiteLLM integrates with vector stores, allowing your models to access your organization's data for more accurate and contextually relevant responses."
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
"totalSize": 299,
|
||||
"attributionToken": "mock_token",
|
||||
"summary": {},
|
||||
}
|
||||
|
||||
|
||||
class TestVertexAISearchAPIVectorStore:
|
||||
"""Test Vertex AI Search API Vector Store with mocked responses"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_search_with_mock(self):
|
||||
"""Test basic vector search with mocked backend response"""
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
|
||||
mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
|
||||
|
||||
# Mock the access token method to avoid real authentication
|
||||
with patch(
|
||||
"litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
|
||||
) as mock_auth:
|
||||
mock_auth.return_value = ("mock_token", "test-vector-store-db")
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Make the search request
|
||||
response = await litellm.vector_stores.asearch(
|
||||
query="what is LiteLLM?",
|
||||
vector_store_id="test-litellm-app_1761094730750",
|
||||
custom_llm_provider="vertex_ai/search_api",
|
||||
vertex_project="test-vector-store-db",
|
||||
vertex_location="us-central1",
|
||||
)
|
||||
|
||||
print("Response:", json.dumps(response, indent=2, default=str))
|
||||
|
||||
# Validate the response structure (LiteLLM standard format)
|
||||
assert response is not None
|
||||
assert response["object"] == "vector_store.search_results.page"
|
||||
assert "data" in response
|
||||
assert len(response["data"]) > 0
|
||||
assert "search_query" in response
|
||||
|
||||
# Validate first result
|
||||
first_result = response["data"][0]
|
||||
assert "score" in first_result
|
||||
assert "content" in first_result
|
||||
assert "file_id" in first_result
|
||||
assert "filename" in first_result
|
||||
assert "attributes" in first_result
|
||||
|
||||
# Validate content structure
|
||||
assert len(first_result["content"]) > 0
|
||||
assert first_result["content"][0]["type"] == "text"
|
||||
assert "text" in first_result["content"][0]
|
||||
|
||||
# Verify the API was called
|
||||
mock_post.assert_called_once()
|
||||
|
||||
# Verify the URL format
|
||||
call_args = mock_post.call_args
|
||||
url = call_args[1]["url"] if "url" in call_args[1] else call_args[0][0]
|
||||
assert "discoveryengine.googleapis.com" in url
|
||||
assert "test-vector-store-db" in url
|
||||
assert "test-litellm-app_1761094730750" in url
|
||||
|
||||
def test_basic_search_sync_with_mock(self):
|
||||
"""Test basic vector search (sync) with mocked backend response"""
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = MOCK_VERTEX_SEARCH_RESPONSE
|
||||
mock_response.text = json.dumps(MOCK_VERTEX_SEARCH_RESPONSE)
|
||||
|
||||
# Mock the access token method to avoid real authentication
|
||||
with patch(
|
||||
"litellm.llms.vertex_ai.vector_stores.search_api.transformation.VertexSearchAPIVectorStoreConfig._ensure_access_token"
|
||||
) as mock_auth:
|
||||
mock_auth.return_value = ("mock_token", "test-vector-store-db")
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post"
|
||||
) as mock_post:
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
# Make the search request
|
||||
response = litellm.vector_stores.search(
|
||||
query="what is LiteLLM?",
|
||||
vector_store_id="test-litellm-app_1761094730750",
|
||||
custom_llm_provider="vertex_ai/search_api",
|
||||
vertex_project="test-vector-store-db",
|
||||
vertex_location="us-central1",
|
||||
)
|
||||
|
||||
print("Response:", json.dumps(response, indent=2, default=str))
|
||||
|
||||
# Validate the response structure (LiteLLM standard format)
|
||||
assert response is not None
|
||||
assert response["object"] == "vector_store.search_results.page"
|
||||
assert "data" in response
|
||||
assert len(response["data"]) > 0
|
||||
assert "search_query" in response
|
||||
|
||||
# Validate first result structure
|
||||
first_result = response["data"][0]
|
||||
assert "score" in first_result
|
||||
assert "content" in first_result
|
||||
assert "file_id" in first_result
|
||||
assert "filename" in first_result
|
||||
assert "attributes" in first_result
|
||||
|
||||
# Validate content structure
|
||||
assert len(first_result["content"]) > 0
|
||||
assert first_result["content"][0]["type"] == "text"
|
||||
assert "text" in first_result["content"][0]
|
||||
|
||||
# Validate attributes
|
||||
assert "document_id" in first_result["attributes"]
|
||||
assert "link" in first_result["attributes"]
|
||||
assert "title" in first_result["attributes"]
|
||||
|
||||
# Verify the API was called
|
||||
mock_post.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
import asyncio
|
||||
|
||||
test = TestVertexAISearchAPIVectorStore()
|
||||
|
||||
print("Running async test...")
|
||||
asyncio.run(test.test_basic_search_with_mock())
|
||||
|
||||
print("\nRunning sync test...")
|
||||
test.test_basic_search_sync_with_mock()
|
||||
|
||||
print("\n✅ All tests passed!")
|
||||
Reference in New Issue
Block a user