mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
[Fix] Bedrock Embeddings - Ensure correct aws_region is used when provided dynamically (#16547)
* test_bedrock_embedding_uses_correct_region_when_specified
* fix aws_region_name in bedrock embeddings
This commit is contained in:
@@ -391,7 +391,7 @@ class BedrockEmbedding(BaseAWSLLM):
|
||||
) # default to model if not passed
|
||||
modelId = urllib.parse.quote(unencoded_model_id, safe="")
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params,
|
||||
optional_params={"aws_region_name": aws_region_name},
|
||||
model=model,
|
||||
model_id=unencoded_model_id,
|
||||
)
|
||||
|
||||
@@ -336,3 +336,116 @@ async def test_e2e_bedrock_async_invoke_embedding_async_twelvelabs_marengo():
|
||||
# Restore original region name
|
||||
if original_region_name:
|
||||
os.environ["AWS_REGION_NAME"] = original_region_name
|
||||
|
||||
|
||||
# Canned JSON payload that the mocked HTTP POST returns in the Bedrock
# embedding region tests in this file (serialized with json.dumps there).
titan_embedding_response = {"embedding": [0.1, 0.2, 0.3], "inputTextTokenCount": 10}
|
||||
|
||||
|
||||
def test_bedrock_embedding_uses_correct_region_when_specified():
    """
    Test that when aws_region_name is explicitly passed, it's used correctly
    even if AWS_REGION_NAME env var is set to a different region.

    The env var is overridden with ``patch.dict`` so the previous environment
    (including an unset or empty-string AWS_REGION_NAME) is restored exactly
    when the block exits, whether the test passes or fails.

    relevant issue: https://github.com/BerriAI/litellm/issues/16517
    """
    # Set env var to a different region (this should NOT be used)
    with patch.dict(os.environ, {"AWS_REGION_NAME": "ap-northeast-1"}):
        client = HTTPHandler()

        with patch.object(client, "post") as mock_post:
            # Minimal fake HTTP response carrying the canned embedding payload
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.text = json.dumps(titan_embedding_response)
            mock_response.json = lambda: json.loads(mock_response.text)
            mock_post.return_value = mock_response

            # Call with explicit region
            litellm.embedding(
                model="bedrock/amazon.titan-embed-image-v1",
                input=["test input"],
                client=client,
                aws_region_name="us-east-1",  # Explicitly set to us-east-1
            )

            # Verify the request was made to the correct region
            assert mock_post.called, "HTTP post should have been called"

            # Get the URL from the call (read from the `url` keyword argument)
            url = mock_post.call_args.kwargs.get("url", "")

            # The URL should contain us-east-1, NOT ap-northeast-1
            assert "us-east-1" in url, f"URL should contain us-east-1, but got: {url}"
            assert (
                "ap-northeast-1" not in url
            ), f"URL should NOT contain ap-northeast-1, but got: {url}"
|
||||
|
||||
|
||||
def test_bedrock_embedding_region_bug_reproduction():
    """
    Reproduces the bug where aws_region_name is ignored when passed explicitly.

    AWS_REGION_NAME is set to ap-northeast-1 while the call passes
    aws_region_name="us-east-1"; the test fails if the request URL contains
    the env-var region. The env var is overridden with ``patch.dict`` so the
    previous environment (including an unset or empty-string value) is
    restored exactly when the block exits.

    relevant issue: https://github.com/BerriAI/litellm/issues/16517
    """
    # Set env var to ap-northeast-1 (this is what the bug report shows)
    with patch.dict(os.environ, {"AWS_REGION_NAME": "ap-northeast-1"}):
        client = HTTPHandler()

        with patch.object(client, "post") as mock_post:
            # Minimal fake HTTP response carrying the canned embedding payload
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.text = json.dumps(titan_embedding_response)
            mock_response.json = lambda: json.loads(mock_response.text)
            mock_post.return_value = mock_response

            # Call with explicit region (as in the bug report)
            litellm.embedding(
                model="bedrock/amazon.titan-embed-image-v1",
                input=["test input"],
                client=client,
                aws_region_name="us-east-1",  # Explicitly set to us-east-1
            )

            # Verify the request was made
            assert mock_post.called, "HTTP post should have been called"

            # Get the URL from the call (read from the `url` keyword argument)
            url = mock_post.call_args.kwargs.get("url", "")

            print(f"Request URL: {url}")
            print("Expected region in URL: us-east-1")
            print(f"Environment AWS_REGION_NAME: {os.environ.get('AWS_REGION_NAME')}")

            # Fails if the bug exists (URL uses ap-northeast-1 from the env var);
            # passes if the bug is fixed (URL uses the explicit us-east-1).
            if "ap-northeast-1" in url:
                print("❌ BUG REPRODUCED: Using wrong region from env var instead of explicit parameter")
                # Raise explicitly: `assert False` would be stripped under -O.
                raise AssertionError(
                    f"Bug reproduced: URL contains ap-northeast-1 instead of us-east-1. URL: {url}"
                )

            print("✓ Bug NOT reproduced: Using correct region from explicit parameter")
            assert "us-east-1" in url, f"URL should contain us-east-1, but got: {url}"
|
||||
Reference in New Issue
Block a user