Merge pull request #17542 from BerriAI/litellm_pcs_vertex_fix

fix failing vertex tests
This commit is contained in:
Sameer Kankute
2025-12-06 01:15:59 +05:30
committed by GitHub
7 changed files with 66 additions and 38 deletions

View File

@@ -1619,7 +1619,8 @@ response = completion(
messages=[{"role": "user", "content": "Hello!"}],
api_base="http://10.96.32.8", # Your PSC endpoint
vertex_project="my-project-id",
vertex_location="us-central1"
vertex_location="us-central1",
use_psc_endpoint_format=True
)
```
@@ -1642,6 +1643,7 @@ model_list:
vertex_project: "my-project-id"
vertex_location: "us-central1"
vertex_credentials: "/path/to/service_account.json"
use_psc_endpoint_format: True
- model_name: psc-embedding
litellm_params:
model: vertex_ai/text-embedding-004
@@ -1649,6 +1651,7 @@ model_list:
vertex_project: "my-project-id"
vertex_location: "us-central1"
vertex_credentials: "/path/to/service_account.json"
use_psc_endpoint_format: True
```
## Fine-tuned Models

View File

@@ -2123,6 +2123,9 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
)
# Extract use_psc_endpoint_format from optional_params
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -2134,6 +2137,7 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
use_psc_endpoint_format=use_psc_endpoint_format,
)
headers = VertexGeminiConfig().validate_environment(
@@ -2217,6 +2221,9 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
)
# Extract use_psc_endpoint_format from optional_params
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -2228,6 +2235,7 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
use_psc_endpoint_format=use_psc_endpoint_format,
)
headers = VertexGeminiConfig().validate_environment(
@@ -2401,6 +2409,9 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
)
# Extract use_psc_endpoint_format from optional_params
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -2412,6 +2423,7 @@ class VertexLLM(VertexBase):
custom_llm_provider=custom_llm_provider,
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
use_psc_endpoint_format=use_psc_endpoint_format,
)
headers = VertexGeminiConfig().validate_environment(
api_key=auth_header,

View File

@@ -72,6 +72,9 @@ class VertexEmbedding(VertexBase):
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
# Extract use_psc_endpoint_format from optional_params
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -84,6 +87,7 @@ class VertexEmbedding(VertexBase):
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
mode="embedding",
use_psc_endpoint_format=use_psc_endpoint_format,
)
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (
@@ -164,6 +168,9 @@ class VertexEmbedding(VertexBase):
project_id=vertex_project,
custom_llm_provider=custom_llm_provider,
)
# Extract use_psc_endpoint_format from optional_params
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
auth_header, api_base = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -176,6 +183,7 @@ class VertexEmbedding(VertexBase):
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
mode="embedding",
use_psc_endpoint_format=use_psc_endpoint_format,
)
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
vertex_request: VertexEmbeddingRequest = (

View File

@@ -296,6 +296,7 @@ class VertexBase:
vertex_project: Optional[str] = None,
vertex_location: Optional[str] = None,
vertex_api_version: Optional[Literal["v1", "v1beta1"]] = None,
use_psc_endpoint_format: bool = False,
) -> Tuple[Optional[str], str]:
"""
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
@@ -305,6 +306,11 @@ class VertexBase:
2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
3. Vertex AI with PSC endpoints - constructs full path structure
{api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
(only when use_psc_endpoint_format=True)
Args:
use_psc_endpoint_format: If True, constructs PSC endpoint URL format.
If False (default), uses api_base as-is and appends :{endpoint}
## Returns
- (auth_header, url) - Tuple[Optional[str], str]
@@ -325,33 +331,25 @@ class VertexBase:
auth_header = {"x-goog-api-key": gemini_api_key} # type: ignore[assignment]
else:
# For Vertex AI
# Check if this is a PSC endpoint or custom deployment
# PSC/custom endpoints need the full path structure
if vertex_project and vertex_location and model:
if use_psc_endpoint_format:
# User explicitly specified PSC endpoint format
# Construct full PSC/custom endpoint URL
if not (vertex_project and vertex_location and model):
raise ValueError(
"vertex_project, vertex_location, and model are required when use_psc_endpoint_format=True"
)
# Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
model_for_url = get_vertex_base_model_name(model=model)
# Check if model is numeric (endpoint ID) or if api_base doesn't contain googleapis.com
# These are indicators of PSC/custom endpoints
is_psc_or_custom = (
"googleapis.com" not in api_base.lower() or model_for_url.isdigit()
# Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
version = vertex_api_version or "v1"
url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
api_base.rstrip("/"),
version,
vertex_project,
vertex_location,
model_for_url,
endpoint,
)
if is_psc_or_custom:
# Construct full PSC/custom endpoint URL
# Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
version = vertex_api_version or "v1"
url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
api_base.rstrip("/"),
version,
vertex_project,
vertex_location,
model_for_url,
endpoint,
)
else:
# Standard proxy - just append endpoint
url = "{}:{}".format(api_base, endpoint)
else:
# Fallback to simple format if we don't have all parameters
url = "{}:{}".format(api_base, endpoint)
@@ -372,6 +370,7 @@ class VertexBase:
api_base: Optional[str],
should_use_v1beta1_features: Optional[bool] = False,
mode: all_gemini_url_modes = "chat",
use_psc_endpoint_format: bool = False,
) -> Tuple[Optional[str], str]:
"""
Internal function. Returns the token and url for the call.
@@ -421,6 +420,7 @@ class VertexBase:
vertex_project=vertex_project,
vertex_location=vertex_location,
vertex_api_version=version,
use_psc_endpoint_format=use_psc_endpoint_format,
)
def _handle_reauthentication(

View File

@@ -1,4 +0,0 @@
"""
RAGFlow chat transformation tests.
"""

View File

@@ -214,7 +214,8 @@ def test_vertex_ai_bge_psc_endpoint_url_construction():
api_base="http://10.128.16.2",
vertex_project="gen-lang-client-0682925754",
vertex_location="us-central1",
client=client
client=client,
use_psc_endpoint_format=True # Enable PSC endpoint format for this test
)
mock_post.assert_called_once()

View File

@@ -26,6 +26,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "1234567890"
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
@@ -39,6 +40,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
@@ -53,7 +55,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "1234567890"
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
custom_llm_provider="vertex_ai",
@@ -66,6 +68,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:streamGenerateContent?alt=sse"
@@ -80,7 +83,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "1234567890"
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
custom_llm_provider="vertex_ai",
@@ -93,6 +96,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1beta1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
expected_url = f"{psc_api_base}/v1beta1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
@@ -107,7 +111,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "1234567890"
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
custom_llm_provider="vertex_ai",
@@ -120,6 +124,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
@@ -134,7 +139,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "1234567890"
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
custom_llm_provider="vertex_ai",
@@ -147,6 +152,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
# rstrip('/') should remove the trailing slash
@@ -162,7 +168,6 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "gemini-pro" # Not numeric
project_id = "test-project"
location = "us-central1"
auth_header, url = vertex_base._check_custom_proxy(
api_base=proxy_api_base,
custom_llm_provider="vertex_ai",
@@ -190,7 +195,7 @@ class TestVertexAIPSCEndpointSupport:
endpoint_id = "9876543210" # Numeric endpoint ID
project_id = "test-project"
location = "us-central1"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=proxy_api_base,
custom_llm_provider="vertex_ai",
@@ -203,6 +208,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project=project_id,
vertex_location=location,
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
# Numeric model should trigger full path construction
@@ -215,7 +221,7 @@ class TestVertexAIPSCEndpointSupport:
"""Test that when api_base is None, the original URL is returned"""
vertex_base = VertexBase()
original_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=None,
custom_llm_provider="vertex_ai",
@@ -228,6 +234,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project="test-project",
vertex_location="us-central1",
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
# When api_base is None, original URL should be returned unchanged
@@ -238,7 +245,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_base = VertexBase()
psc_api_base = "http://10.96.32.8"
test_auth_header = "Bearer test-token-12345"
use_psc_endpoint_format = True
auth_header, url = vertex_base._check_custom_proxy(
api_base=psc_api_base,
custom_llm_provider="vertex_ai",
@@ -251,6 +258,7 @@ class TestVertexAIPSCEndpointSupport:
vertex_project="test-project",
vertex_location="us-central1",
vertex_api_version="v1",
use_psc_endpoint_format=use_psc_endpoint_format,
)
assert (