mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
Merge pull request #17542 from BerriAI/litellm_pcs_vertex_fix
fix failing vertex tests
This commit is contained in:
@@ -1619,7 +1619,8 @@ response = completion(
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
api_base="http://10.96.32.8", # Your PSC endpoint
|
||||
vertex_project="my-project-id",
|
||||
vertex_location="us-central1"
|
||||
vertex_location="us-central1",
|
||||
use_psc_endpoint_format=True
|
||||
)
|
||||
```
|
||||
|
||||
@@ -1642,6 +1643,7 @@ model_list:
|
||||
vertex_project: "my-project-id"
|
||||
vertex_location: "us-central1"
|
||||
vertex_credentials: "/path/to/service_account.json"
|
||||
use_psc_endpoint_format: True
|
||||
- model_name: psc-embedding
|
||||
litellm_params:
|
||||
model: vertex_ai/text-embedding-004
|
||||
@@ -1649,6 +1651,7 @@ model_list:
|
||||
vertex_project: "my-project-id"
|
||||
vertex_location: "us-central1"
|
||||
vertex_credentials: "/path/to/service_account.json"
|
||||
use_psc_endpoint_format: True
|
||||
```
|
||||
|
||||
## Fine-tuned Models
|
||||
|
||||
@@ -2123,6 +2123,9 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# Extract use_psc_endpoint_format from optional_params
|
||||
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=gemini_api_key,
|
||||
@@ -2134,6 +2137,7 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
headers = VertexGeminiConfig().validate_environment(
|
||||
@@ -2217,6 +2221,9 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# Extract use_psc_endpoint_format from optional_params
|
||||
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=gemini_api_key,
|
||||
@@ -2228,6 +2235,7 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
headers = VertexGeminiConfig().validate_environment(
|
||||
@@ -2401,6 +2409,9 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# Extract use_psc_endpoint_format from optional_params
|
||||
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
|
||||
|
||||
auth_header, url = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=gemini_api_key,
|
||||
@@ -2412,6 +2423,7 @@ class VertexLLM(VertexBase):
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
headers = VertexGeminiConfig().validate_environment(
|
||||
api_key=auth_header,
|
||||
|
||||
@@ -72,6 +72,9 @@ class VertexEmbedding(VertexBase):
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
# Extract use_psc_endpoint_format from optional_params
|
||||
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=gemini_api_key,
|
||||
@@ -84,6 +87,7 @@ class VertexEmbedding(VertexBase):
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
mode="embedding",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||
vertex_request: VertexEmbeddingRequest = (
|
||||
@@ -164,6 +168,9 @@ class VertexEmbedding(VertexBase):
|
||||
project_id=vertex_project,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
# Extract use_psc_endpoint_format from optional_params
|
||||
use_psc_endpoint_format = optional_params.get("use_psc_endpoint_format", False)
|
||||
|
||||
auth_header, api_base = self._get_token_and_url(
|
||||
model=model,
|
||||
gemini_api_key=gemini_api_key,
|
||||
@@ -176,6 +183,7 @@ class VertexEmbedding(VertexBase):
|
||||
api_base=api_base,
|
||||
should_use_v1beta1_features=should_use_v1beta1_features,
|
||||
mode="embedding",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
|
||||
vertex_request: VertexEmbeddingRequest = (
|
||||
|
||||
@@ -296,6 +296,7 @@ class VertexBase:
|
||||
vertex_project: Optional[str] = None,
|
||||
vertex_location: Optional[str] = None,
|
||||
vertex_api_version: Optional[Literal["v1", "v1beta1"]] = None,
|
||||
use_psc_endpoint_format: bool = False,
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
|
||||
@@ -305,6 +306,11 @@ class VertexBase:
|
||||
2. Vertex AI with standard proxies - constructs {api_base}:{endpoint}
|
||||
3. Vertex AI with PSC endpoints - constructs full path structure
|
||||
{api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
|
||||
(only when use_psc_endpoint_format=True)
|
||||
|
||||
Args:
|
||||
use_psc_endpoint_format: If True, constructs PSC endpoint URL format.
|
||||
If False (default), uses api_base as-is and appends :{endpoint}
|
||||
|
||||
## Returns
|
||||
- (auth_header, url) - Tuple[Optional[str], str]
|
||||
@@ -325,33 +331,25 @@ class VertexBase:
|
||||
auth_header = {"x-goog-api-key": gemini_api_key} # type: ignore[assignment]
|
||||
else:
|
||||
# For Vertex AI
|
||||
# Check if this is a PSC endpoint or custom deployment
|
||||
# PSC/custom endpoints need the full path structure
|
||||
if vertex_project and vertex_location and model:
|
||||
if use_psc_endpoint_format:
|
||||
# User explicitly specified PSC endpoint format
|
||||
# Construct full PSC/custom endpoint URL
|
||||
if not (vertex_project and vertex_location and model):
|
||||
raise ValueError(
|
||||
"vertex_project, vertex_location, and model are required when use_psc_endpoint_format=True"
|
||||
)
|
||||
# Strip routing prefixes (bge/, gemma/, etc.) for endpoint URL construction
|
||||
model_for_url = get_vertex_base_model_name(model=model)
|
||||
|
||||
# Check if model is numeric (endpoint ID) or if api_base doesn't contain googleapis.com
|
||||
# These are indicators of PSC/custom endpoints
|
||||
is_psc_or_custom = (
|
||||
"googleapis.com" not in api_base.lower() or model_for_url.isdigit()
|
||||
# Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
|
||||
version = vertex_api_version or "v1"
|
||||
url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
|
||||
api_base.rstrip("/"),
|
||||
version,
|
||||
vertex_project,
|
||||
vertex_location,
|
||||
model_for_url,
|
||||
endpoint,
|
||||
)
|
||||
|
||||
if is_psc_or_custom:
|
||||
# Construct full PSC/custom endpoint URL
|
||||
# Format: {api_base}/v1/projects/{project}/locations/{location}/endpoints/{model}:{endpoint}
|
||||
version = vertex_api_version or "v1"
|
||||
url = "{}/{}/projects/{}/locations/{}/endpoints/{}:{}".format(
|
||||
api_base.rstrip("/"),
|
||||
version,
|
||||
vertex_project,
|
||||
vertex_location,
|
||||
model_for_url,
|
||||
endpoint,
|
||||
)
|
||||
else:
|
||||
# Standard proxy - just append endpoint
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
else:
|
||||
# Fallback to simple format if we don't have all parameters
|
||||
url = "{}:{}".format(api_base, endpoint)
|
||||
@@ -372,6 +370,7 @@ class VertexBase:
|
||||
api_base: Optional[str],
|
||||
should_use_v1beta1_features: Optional[bool] = False,
|
||||
mode: all_gemini_url_modes = "chat",
|
||||
use_psc_endpoint_format: bool = False,
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Internal function. Returns the token and url for the call.
|
||||
@@ -421,6 +420,7 @@ class VertexBase:
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_api_version=version,
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
def _handle_reauthentication(
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""
|
||||
RAGFlow chat transformation tests.
|
||||
"""
|
||||
|
||||
@@ -214,7 +214,8 @@ def test_vertex_ai_bge_psc_endpoint_url_construction():
|
||||
api_base="http://10.128.16.2",
|
||||
vertex_project="gen-lang-client-0682925754",
|
||||
vertex_location="us-central1",
|
||||
client=client
|
||||
client=client,
|
||||
use_psc_endpoint_format=True # Enable PSC endpoint format for this test
|
||||
)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@@ -26,6 +26,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "1234567890"
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
use_psc_endpoint_format = True
|
||||
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
@@ -39,6 +40,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
|
||||
@@ -53,7 +55,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "1234567890"
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -66,6 +68,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:streamGenerateContent?alt=sse"
|
||||
@@ -80,7 +83,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "1234567890"
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -93,6 +96,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1beta1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
expected_url = f"{psc_api_base}/v1beta1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
|
||||
@@ -107,7 +111,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "1234567890"
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -120,6 +124,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
expected_url = f"{psc_api_base}/v1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}:predict"
|
||||
@@ -134,7 +139,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "1234567890"
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -147,6 +152,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
# rstrip('/') should remove the trailing slash
|
||||
@@ -162,7 +168,6 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "gemini-pro" # Not numeric
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=proxy_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -190,7 +195,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
endpoint_id = "9876543210" # Numeric endpoint ID
|
||||
project_id = "test-project"
|
||||
location = "us-central1"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=proxy_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -203,6 +208,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
# Numeric model should trigger full path construction
|
||||
@@ -215,7 +221,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
"""Test that when api_base is None, the original URL is returned"""
|
||||
vertex_base = VertexBase()
|
||||
original_url = "https://us-central1-aiplatform.googleapis.com/v1/projects/test/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=None,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -228,6 +234,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project="test-project",
|
||||
vertex_location="us-central1",
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
# When api_base is None, original URL should be returned unchanged
|
||||
@@ -238,7 +245,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_base = VertexBase()
|
||||
psc_api_base = "http://10.96.32.8"
|
||||
test_auth_header = "Bearer test-token-12345"
|
||||
|
||||
use_psc_endpoint_format = True
|
||||
auth_header, url = vertex_base._check_custom_proxy(
|
||||
api_base=psc_api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
@@ -251,6 +258,7 @@ class TestVertexAIPSCEndpointSupport:
|
||||
vertex_project="test-project",
|
||||
vertex_location="us-central1",
|
||||
vertex_api_version="v1",
|
||||
use_psc_endpoint_format=use_psc_endpoint_format,
|
||||
)
|
||||
|
||||
assert (
|
||||
|
||||
Reference in New Issue
Block a user