mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
fix(agentcore): Convert SSE stream iterator to async for proper streaming support (#16293)
* fix(agentcore): support async agentcore runtime streaming * revert: CLAUDE.md * revert: .gitignore * fix: map runtimeUserId to X-Amzn-Bedrock-AgentCore-Runtime-User-Id header for runtime oauth support
This commit is contained in:
@@ -4,9 +4,51 @@ cookbook
|
|||||||
.github
|
.github
|
||||||
tests
|
tests
|
||||||
.git
|
.git
|
||||||
.github
|
|
||||||
.circleci
|
|
||||||
.devcontainer
|
.devcontainer
|
||||||
*.tgz
|
*.tgz
|
||||||
log.txt
|
log.txt
|
||||||
docker/Dockerfile.*
|
docker/Dockerfile.*
|
||||||
|
|
||||||
|
# Claude Flow generated files (must be excluded from Docker build)
|
||||||
|
.claude/
|
||||||
|
.claude-flow/
|
||||||
|
.swarm/
|
||||||
|
.hive-mind/
|
||||||
|
memory/
|
||||||
|
coordination/
|
||||||
|
claude-flow
|
||||||
|
.mcp.json
|
||||||
|
hive-mind-prompt-*.txt
|
||||||
|
|
||||||
|
# Python virtual environments and version managers
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
**/.venv/
|
||||||
|
**/venv/
|
||||||
|
.python-version
|
||||||
|
.pyenv/
|
||||||
|
__pycache__/
|
||||||
|
**/__pycache__/
|
||||||
|
*.pyc
|
||||||
|
.mypy_cache/
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
**/pyvenv.cfg
|
||||||
|
|
||||||
|
# Common project exclusions
|
||||||
|
.vscode
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.Python
|
||||||
|
env/
|
||||||
|
.pytest_cache
|
||||||
|
.coverage
|
||||||
|
htmlcov/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
*.egg-info/
|
||||||
|
.DS_Store
|
||||||
|
node_modules/
|
||||||
|
*.log
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
|||||||
@@ -19,21 +19,21 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class AgentCoreSSEStreamIterator:
|
class AgentCoreSSEStreamIterator:
|
||||||
"""Iterator for AgentCore SSE streaming responses."""
|
"""Async iterator for AgentCore SSE streaming responses."""
|
||||||
|
|
||||||
def __init__(self, response: httpx.Response, model: str):
|
def __init__(self, response: httpx.Response, model: str):
|
||||||
self.response = response
|
self.response = response
|
||||||
self.model = model
|
self.model = model
|
||||||
self.finished = False
|
self.finished = False
|
||||||
self.line_iterator = self.response.iter_lines()
|
self.line_iterator = self.response.aiter_lines()
|
||||||
|
|
||||||
def __iter__(self):
|
def __aiter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __next__(self) -> ModelResponse:
|
async def __anext__(self) -> ModelResponse:
|
||||||
"""Parse SSE events and yield ModelResponse chunks."""
|
"""Parse SSE events and yield ModelResponse chunks."""
|
||||||
try:
|
try:
|
||||||
for line in self.line_iterator:
|
async for line in self.line_iterator:
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
|
|
||||||
if not line or not line.startswith('data:'):
|
if not line or not line.startswith('data:'):
|
||||||
@@ -134,17 +134,17 @@ class AgentCoreSSEStreamIterator:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Stream ended naturally
|
# Stream ended naturally
|
||||||
raise StopIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
except StopIteration:
|
except StopAsyncIteration:
|
||||||
raise
|
raise
|
||||||
except httpx.StreamConsumed:
|
except httpx.StreamConsumed:
|
||||||
# This is expected when the stream has been fully consumed
|
# This is expected when the stream has been fully consumed
|
||||||
raise StopIteration
|
raise StopAsyncIteration
|
||||||
except httpx.StreamClosed:
|
except httpx.StreamClosed:
|
||||||
# This is expected when the stream is closed
|
# This is expected when the stream is closed
|
||||||
raise StopIteration
|
raise StopAsyncIteration
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.error(f"Error in AgentCore SSE stream: {str(e)}")
|
verbose_logger.error(f"Error in AgentCore SSE stream: {str(e)}")
|
||||||
raise StopIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|||||||
@@ -158,10 +158,16 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
|||||||
session_id = optional_params.get("runtimeSessionId", None)
|
session_id = optional_params.get("runtimeSessionId", None)
|
||||||
if session_id:
|
if session_id:
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
# Generate a session ID with 33+ characters
|
# Generate a session ID with 33+ characters
|
||||||
return f"litellm-session-{str(uuid.uuid4())}"
|
return f"litellm-session-{str(uuid.uuid4())}"
|
||||||
|
|
||||||
|
def _get_runtime_user_id(self, optional_params: dict) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get runtime user ID if provided
|
||||||
|
"""
|
||||||
|
return optional_params.get("runtimeUserId", None)
|
||||||
|
|
||||||
def transform_request(
|
def transform_request(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
@@ -172,28 +178,34 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Transform the request to AgentCore format.
|
Transform the request to AgentCore format.
|
||||||
|
|
||||||
Based on boto3's implementation:
|
Based on boto3's implementation:
|
||||||
- Session ID goes in header: X-Amzn-Bedrock-AgentCore-Runtime-Session-Id
|
- Session ID goes in header: X-Amzn-Bedrock-AgentCore-Runtime-Session-Id
|
||||||
|
- User ID goes in header: X-Amzn-Bedrock-AgentCore-Runtime-User-Id
|
||||||
- Qualifier goes as query parameter
|
- Qualifier goes as query parameter
|
||||||
- Only the payload goes in the request body
|
- Only the payload goes in the request body
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Payload dict containing the prompt
|
dict: Payload dict containing the prompt
|
||||||
"""
|
"""
|
||||||
# Use the last message content as the prompt
|
# Use the last message content as the prompt
|
||||||
prompt = convert_content_list_to_str(messages[-1])
|
prompt = convert_content_list_to_str(messages[-1])
|
||||||
|
|
||||||
# Create the payload - this is what goes in the body (raw JSON)
|
# Create the payload - this is what goes in the body (raw JSON)
|
||||||
payload: dict = {"prompt": prompt}
|
payload: dict = {"prompt": prompt}
|
||||||
|
|
||||||
# Get or generate session ID - this goes in the header
|
# Get or generate session ID - this goes in the header
|
||||||
runtime_session_id = self._get_runtime_session_id(optional_params)
|
runtime_session_id = self._get_runtime_session_id(optional_params)
|
||||||
headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] = runtime_session_id
|
headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] = runtime_session_id
|
||||||
|
|
||||||
|
# Get user ID if provided - this goes in the header
|
||||||
|
runtime_user_id = self._get_runtime_user_id(optional_params)
|
||||||
|
if runtime_user_id:
|
||||||
|
headers["X-Amzn-Bedrock-AgentCore-Runtime-User-Id"] = runtime_user_id
|
||||||
|
|
||||||
# The request data is the payload dict (will be JSON encoded by the HTTP handler)
|
# The request data is the payload dict (will be JSON encoded by the HTTP handler)
|
||||||
# Qualifier will be handled as a query parameter in get_complete_url
|
# Qualifier will be handled as a query parameter in get_complete_url
|
||||||
|
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
def _extract_sse_json(self, line: str) -> Optional[Dict]:
|
def _extract_sse_json(self, line: str) -> Optional[Dict]:
|
||||||
@@ -480,6 +492,67 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
|||||||
|
|
||||||
return streaming_response
|
return streaming_response
|
||||||
|
|
||||||
|
async def get_async_custom_stream_wrapper(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
custom_llm_provider: str,
|
||||||
|
logging_obj: LiteLLMLoggingObj,
|
||||||
|
api_base: str,
|
||||||
|
headers: dict,
|
||||||
|
data: dict,
|
||||||
|
messages: list,
|
||||||
|
client: Optional["AsyncHTTPHandler"] = None,
|
||||||
|
json_mode: Optional[bool] = None,
|
||||||
|
signed_json_body: Optional[bytes] = None,
|
||||||
|
) -> CustomStreamWrapper:
|
||||||
|
"""
|
||||||
|
Get a CustomStreamWrapper for asynchronous streaming.
|
||||||
|
|
||||||
|
This is called when stream=True is passed to acompletion().
|
||||||
|
"""
|
||||||
|
from litellm.llms.custom_httpx.http_handler import (
|
||||||
|
AsyncHTTPHandler,
|
||||||
|
get_async_httpx_client,
|
||||||
|
)
|
||||||
|
from litellm.utils import CustomStreamWrapper
|
||||||
|
|
||||||
|
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||||
|
client = get_async_httpx_client(llm_provider="bedrock", params={})
|
||||||
|
|
||||||
|
# Make async streaming request
|
||||||
|
response = await client.post(
|
||||||
|
api_base,
|
||||||
|
headers=headers,
|
||||||
|
data=signed_json_body if signed_json_body else json.dumps(data),
|
||||||
|
stream=True, # THIS IS KEY - tells httpx to not buffer
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise BedrockError(
|
||||||
|
status_code=response.status_code, message=str(await response.aread())
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create iterator for SSE stream
|
||||||
|
completion_stream = self.get_streaming_response(model=model, raw_response=response)
|
||||||
|
|
||||||
|
streaming_response = CustomStreamWrapper(
|
||||||
|
completion_stream=completion_stream,
|
||||||
|
model=model,
|
||||||
|
custom_llm_provider=custom_llm_provider,
|
||||||
|
logging_obj=logging_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
# LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=messages,
|
||||||
|
api_key="",
|
||||||
|
original_response="first stream response received",
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
)
|
||||||
|
|
||||||
|
return streaming_response
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_custom_stream_wrapper(self) -> bool:
|
def has_custom_stream_wrapper(self) -> bool:
|
||||||
"""Indicates that this config has custom streaming support."""
|
"""Indicates that this config has custom streaming support."""
|
||||||
|
|||||||
Reference in New Issue
Block a user