Use safe loading of creds (#16479)

This commit is contained in:
Sameer Kankute
2025-11-12 08:46:29 +05:30
committed by GitHub
parent 06efc7631d
commit 517eb0ee10
2 changed files with 54 additions and 5 deletions

View File

@@ -40,8 +40,8 @@ class VertexAIRerankConfig(BaseRerankConfig, VertexBase):
params = optional_params or {}
# Get credentials to extract project ID if needed
vertex_credentials = self.get_vertex_ai_credentials(params.copy())
vertex_project = self.get_vertex_ai_project(params.copy())
vertex_credentials = self.safe_get_vertex_ai_credentials(params.copy())
vertex_project = self.safe_get_vertex_ai_project(params.copy())
# Use _ensure_access_token to extract project_id from credentials
# This is the same method used in vertex embeddings
@@ -76,9 +76,9 @@ class VertexAIRerankConfig(BaseRerankConfig, VertexBase):
Validate and set up authentication for Vertex AI Discovery Engine API
"""
# Get credentials and project info from optional_params (which contains vertex_credentials, etc.)
litellm_params = optional_params or {}
vertex_credentials = self.get_vertex_ai_credentials(litellm_params)
vertex_project = self.get_vertex_ai_project(litellm_params)
litellm_params = optional_params.copy() if optional_params else {}
vertex_credentials = self.safe_get_vertex_ai_credentials(litellm_params)
vertex_project = self.safe_get_vertex_ai_project(litellm_params)
# Get access token using the base class method
access_token, project_id = self._ensure_access_token(

View File

@@ -449,3 +449,52 @@ class TestVertexAIRerankTransform:
"X-Goog-User-Project": "test-project-123"
}
assert headers == expected_headers
@patch('litellm.llms.vertex_ai.rerank.transformation.VertexAIRerankConfig._ensure_access_token')
def test_validate_environment_preserves_optional_params_for_get_complete_url(
self,
mock_ensure_access_token,
):
"""
Validate that calling validate_environment does not remove vertex-specific
parameters needed later by get_complete_url.
"""
mock_ensure_access_token.return_value = ("test-access-token", "project-from-token")
optional_params = {
"vertex_credentials": "path/to/credentials.json",
"vertex_project": "custom-project-id",
}
# Call validate_environment first this previously popped the values in-place
self.config.validate_environment(
headers={},
model=self.model,
api_key=None,
optional_params=optional_params,
)
# Ensure the original optional_params dict still retains the vertex keys
assert optional_params["vertex_credentials"] == "path/to/credentials.json"
assert optional_params["vertex_project"] == "custom-project-id"
# get_complete_url should still be able to access the vertex params
with patch('litellm.llms.vertex_ai.rerank.transformation.get_secret_str', return_value=None):
url = self.config.get_complete_url(
api_base=None,
model=self.model,
optional_params=optional_params,
)
expected_url = (
"https://discoveryengine.googleapis.com/v1/projects/project-from-token/"
"locations/global/rankingConfigs/default_ranking_config:rank"
)
assert url == expected_url
# _ensure_access_token should have been called twice with the same credentials
assert mock_ensure_access_token.call_count == 2
first_call = mock_ensure_access_token.call_args_list[0]
second_call = mock_ensure_access_token.call_args_list[1]
assert first_call.kwargs["credentials"] == "path/to/credentials.json"
assert second_call.kwargs["credentials"] == "path/to/credentials.json"