mirror of https://github.com/BerriAI/litellm.git (synced 2025-12-06 11:33:26 +08:00)
fix(agentcore): Convert SSE stream iterator to async for proper streaming support (#16293)
* fix(agentcore): support async agentcore runtime streaming
* revert: CLAUDE.md
* revert: .gitignore
* fix: map runtimeUserId to X-Amzn-Bedrock-AgentCore-Runtime-User-Id header for runtime oauth support
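In practice, the change means an AgentCore response can be consumed with async for without blocking the event loop. A minimal usage sketch before the diff itself — the model string and ARN below are illustrative placeholders, not part of this commit:

import asyncio
import litellm

async def main():
    # Assumes an AgentCore runtime model; the ARN is a placeholder.
    response = await litellm.acompletion(
        model="bedrock/agentcore/arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-agent",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    # Chunks now arrive as the runtime emits SSE events, instead of only
    # after the whole response body has been buffered.
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())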
@@ -4,9 +4,51 @@ cookbook
 .github
 tests
 .git
 .github
 .circleci
 .devcontainer
 *.tgz
 log.txt
 docker/Dockerfile.*
+
+# Claude Flow generated files (must be excluded from Docker build)
+.claude/
+.claude-flow/
+.swarm/
+.hive-mind/
+memory/
+coordination/
+claude-flow
+.mcp.json
+hive-mind-prompt-*.txt
+
+# Python virtual environments and version managers
+.venv/
+venv/
+**/.venv/
+**/venv/
+.python-version
+.pyenv/
+__pycache__/
+**/__pycache__/
+*.pyc
+.mypy_cache/
+.pytest_cache/
+.ruff_cache/
+**/pyvenv.cfg
+
+# Common project exclusions
+.vscode
+*.pyo
+*.pyd
+.Python
+env/
+.pytest_cache
+.coverage
+htmlcov/
+dist/
+build/
+*.egg-info/
+.DS_Store
+node_modules/
+*.log
+.env
+.env.local
@@ -19,21 +19,21 @@ if TYPE_CHECKING:


 class AgentCoreSSEStreamIterator:
-    """Iterator for AgentCore SSE streaming responses."""
+    """Async iterator for AgentCore SSE streaming responses."""

     def __init__(self, response: httpx.Response, model: str):
         self.response = response
         self.model = model
         self.finished = False
-        self.line_iterator = self.response.iter_lines()
+        self.line_iterator = self.response.aiter_lines()

-    def __iter__(self):
+    def __aiter__(self):
         return self

-    def __next__(self) -> ModelResponse:
+    async def __anext__(self) -> ModelResponse:
         """Parse SSE events and yield ModelResponse chunks."""
         try:
-            for line in self.line_iterator:
+            async for line in self.line_iterator:
                 line = line.strip()

                 if not line or not line.startswith('data:'):
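The new shape follows Python's async iterator protocol: __aiter__ returns the iterator and an awaitable __anext__ pulls the next item. A self-contained toy version of the same pattern, with an in-memory list standing in for the httpx response (all names here are illustrative):

import asyncio

class ToyLineIterator:
    """Mimics the iterator above with an in-memory line source."""

    def __init__(self, lines):
        # In the real class this is response.aiter_lines(); here we wrap
        # a plain list in an async generator.
        self.line_iterator = self._agen(lines)

    async def _agen(self, lines):
        for line in lines:
            yield line

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Exhaustion of the underlying async generator raises
        # StopAsyncIteration, which cleanly ends an `async for` loop.
        return await self.line_iterator.__anext__()

async def main():
    async for line in ToyLineIterator(['data: {"x": 1}', ""]):
        print(repr(line))

asyncio.run(main())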
@@ -134,17 +134,17 @@ class AgentCoreSSEStreamIterator:
                     continue

             # Stream ended naturally
-            raise StopIteration
+            raise StopAsyncIteration

-        except StopIteration:
+        except StopAsyncIteration:
             raise
         except httpx.StreamConsumed:
             # This is expected when the stream has been fully consumed
-            raise StopIteration
+            raise StopAsyncIteration
         except httpx.StreamClosed:
             # This is expected when the stream is closed
-            raise StopIteration
+            raise StopAsyncIteration
         except Exception as e:
             verbose_logger.error(f"Error in AgentCore SSE stream: {str(e)}")
-            raise StopIteration
+            raise StopAsyncIteration
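The exception swap is load-bearing, not cosmetic: async for only treats StopAsyncIteration as end-of-stream, and a StopIteration escaping a coroutine is converted by the interpreter into a RuntimeError. A small standalone demonstration of the failure mode the old code would have hit:

import asyncio

class BrokenIterator:
    def __aiter__(self):
        return self

    async def __anext__(self):
        raise StopIteration  # wrong exception type for async iteration

async def main():
    try:
        async for _ in BrokenIterator():
            pass
    except RuntimeError as e:
        # CPython reports: "coroutine raised StopIteration"
        print(f"stream crashed instead of ending: {e}")

asyncio.run(main())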
@@ -158,10 +158,16 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
         session_id = optional_params.get("runtimeSessionId", None)
         if session_id:
             return session_id

         # Generate a session ID with 33+ characters
         return f"litellm-session-{str(uuid.uuid4())}"

+    def _get_runtime_user_id(self, optional_params: dict) -> Optional[str]:
+        """
+        Get runtime user ID if provided
+        """
+        return optional_params.get("runtimeUserId", None)
+
     def transform_request(
         self,
         model: str,
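Callers surface these values as provider-specific parameters. A hedged usage sketch, assuming litellm passes them through to optional_params (the model string is again a placeholder; runtimeSessionId and runtimeUserId are the keys read by the helpers above):

import litellm

response = litellm.completion(
    model="bedrock/agentcore/arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/my-agent",
    messages=[{"role": "user", "content": "Hi"}],
    # Picked up by _get_runtime_session_id; must be 33+ characters.
    runtimeSessionId="session-1234567890-abcdefghijklmnop",
    # Picked up by _get_runtime_user_id and sent as
    # X-Amzn-Bedrock-AgentCore-Runtime-User-Id (enables runtime OAuth).
    runtimeUserId="user-123",
)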
@@ -172,28 +178,34 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
     ) -> dict:
         """
         Transform the request to AgentCore format.

         Based on boto3's implementation:
         - Session ID goes in header: X-Amzn-Bedrock-AgentCore-Runtime-Session-Id
+        - User ID goes in header: X-Amzn-Bedrock-AgentCore-Runtime-User-Id
         - Qualifier goes as query parameter
         - Only the payload goes in the request body

         Returns:
             dict: Payload dict containing the prompt
         """
         # Use the last message content as the prompt
         prompt = convert_content_list_to_str(messages[-1])

         # Create the payload - this is what goes in the body (raw JSON)
         payload: dict = {"prompt": prompt}

         # Get or generate session ID - this goes in the header
         runtime_session_id = self._get_runtime_session_id(optional_params)
         headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] = runtime_session_id

+        # Get user ID if provided - this goes in the header
+        runtime_user_id = self._get_runtime_user_id(optional_params)
+        if runtime_user_id:
+            headers["X-Amzn-Bedrock-AgentCore-Runtime-User-Id"] = runtime_user_id
+
         # The request data is the payload dict (will be JSON encoded by the HTTP handler)
         # Qualifier will be handled as a query parameter in get_complete_url

         return payload

     def _extract_sse_json(self, line: str) -> Optional[Dict]:
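Put together, the transformed request separates routing metadata (headers, query string) from the body. A worked example with illustrative values, mirroring the logic above:

import json
import uuid

headers: dict = {}
optional_params = {"runtimeUserId": "user-123"}  # no runtimeSessionId given

# A generated session ID goes in a header ...
headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] = f"litellm-session-{uuid.uuid4()}"
# ... the user ID header is only set when a user ID was supplied ...
if optional_params.get("runtimeUserId"):
    headers["X-Amzn-Bedrock-AgentCore-Runtime-User-Id"] = optional_params["runtimeUserId"]

# ... and the JSON body carries nothing but the prompt.
body = json.dumps({"prompt": "What is the capital of France?"})
print(headers, body)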
@@ -480,6 +492,67 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):

         return streaming_response

+    async def get_async_custom_stream_wrapper(
+        self,
+        model: str,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        messages: list,
+        client: Optional["AsyncHTTPHandler"] = None,
+        json_mode: Optional[bool] = None,
+        signed_json_body: Optional[bytes] = None,
+    ) -> CustomStreamWrapper:
+        """
+        Get a CustomStreamWrapper for asynchronous streaming.
+
+        This is called when stream=True is passed to acompletion().
+        """
+        from litellm.llms.custom_httpx.http_handler import (
+            AsyncHTTPHandler,
+            get_async_httpx_client,
+        )
+        from litellm.utils import CustomStreamWrapper
+
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(llm_provider="bedrock", params={})
+
+        # Make async streaming request
+        response = await client.post(
+            api_base,
+            headers=headers,
+            data=signed_json_body if signed_json_body else json.dumps(data),
+            stream=True,  # THIS IS KEY - tells httpx to not buffer
+            logging_obj=logging_obj,
+        )
+
+        if response.status_code != 200:
+            raise BedrockError(
+                status_code=response.status_code, message=str(await response.aread())
+            )
+
+        # Create iterator for SSE stream
+        completion_stream = self.get_streaming_response(model=model, raw_response=response)
+
+        streaming_response = CustomStreamWrapper(
+            completion_stream=completion_stream,
+            model=model,
+            custom_llm_provider=custom_llm_provider,
+            logging_obj=logging_obj,
+        )
+
+        # LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response="first stream response received",
+            additional_args={"complete_input_dict": data},
+        )
+
+        return streaming_response
+
     @property
     def has_custom_stream_wrapper(self) -> bool:
         """Indicates that this config has custom streaming support."""
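Under the hood the wrapper depends on httpx's streaming mode. Reduced to plain httpx — a sketch, not litellm's AsyncHTTPHandler — the pattern looks like this:

import httpx

async def stream_sse_lines(url: str, headers: dict, body: bytes):
    """Yield SSE data payloads as they arrive, without buffering the body."""
    async with httpx.AsyncClient() as client:
        # client.stream() is what stream=True maps to: the response body is
        # left open so aiter_lines() can consume it incrementally.
        async with client.stream("POST", url, headers=headers, content=body) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    yield line[len("data:"):].strip()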