Mirror of https://github.com/BerriAI/litellm.git (synced 2025-12-06)
[Feat] Agent Gateway - Allow tracking request / response in "Logs" Page (#17449)
* init litellm A2A client
* simpler A2A client interface
* test A2A
* move A2A invoking tests
* test fix
* ensure A2A send_message is tracked in logs
* rename tags
* add streaming handling
* add A2A invocation
* add A2A invocation in cost calc
* test_a2a_logging_payload
* update invoke_agent_a2a
* test_invoke_agent_a2a_adds_litellm_data
* add A2A agent
litellm/a2a/cost_calculator.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""
Cost calculator for A2A (Agent-to-Agent) calls.
"""

from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LitellmLoggingObject,
    )
else:
    LitellmLoggingObject = Any


class A2ACostCalculator:
    @staticmethod
    def calculate_a2a_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """
        Calculate the cost of an A2A send_message call.

        Default is 0.0. In the future, users can configure cost per agent call.
        """
        if litellm_logging_obj is None:
            return 0.0

        # Check if user set a custom response cost
        response_cost = litellm_logging_obj.model_call_details.get(
            "response_cost", None
        )
        if response_cost is not None:
            return response_cost

        # Default to 0.0 for A2A calls
        return 0.0
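For reference, a minimal behavioral sketch of this calculator; the MockLogging class below is hypothetical and stands in for litellm's Logging object, for illustration only:

```python
from litellm.a2a.cost_calculator import A2ACostCalculator


class MockLogging:
    """Hypothetical stand-in for litellm's Logging object (illustration only)."""

    def __init__(self, response_cost=None):
        self.model_call_details = {"response_cost": response_cost}


assert A2ACostCalculator.calculate_a2a_cost(litellm_logging_obj=None) == 0.0
assert A2ACostCalculator.calculate_a2a_cost(MockLogging()) == 0.0      # no override -> default
assert A2ACostCalculator.calculate_a2a_cost(MockLogging(0.25)) == 0.25  # user-set response_cost wins
```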
@@ -38,6 +38,38 @@ except ImportError:
    pass


+def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
+    """
+    Extract agent info and set model/custom_llm_provider for cost tracking.
+
+    Sets model info on the litellm_logging_obj if available.
+    Returns the agent name for logging.
+    """
+    agent_name = "unknown"
+
+    # Try to get agent card from our stored attribute first, then fallback to SDK attribute
+    agent_card = getattr(a2a_client, "_litellm_agent_card", None)
+    if agent_card is None:
+        agent_card = getattr(a2a_client, "agent_card", None)
+
+    if agent_card is not None:
+        agent_name = getattr(agent_card, "name", "unknown") or "unknown"
+
+    # Build model string
+    model = f"a2a_agent/{agent_name}"
+    custom_llm_provider = "a2a_agent"
+
+    # Set on litellm_logging_obj if available (for standard logging payload)
+    litellm_logging_obj = kwargs.get("litellm_logging_obj")
+    if litellm_logging_obj is not None:
+        litellm_logging_obj.model = model
+        litellm_logging_obj.custom_llm_provider = custom_llm_provider
+        litellm_logging_obj.model_call_details["model"] = model
+        litellm_logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
+
+    return agent_name


@client
async def asend_message(
    a2a_client: "A2AClientType",
@@ -80,7 +112,9 @@ async def asend_message(
        response = await asend_message(a2a_client=a2a_client, request=request)
        ```
    """
-    verbose_logger.info(f"A2A send_message request_id={request.id}")
+    agent_name = _get_a2a_model_info(a2a_client, kwargs)
+
+    verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")

    a2a_response = await a2a_client.send_message(request)
@@ -207,12 +241,15 @@ async def create_a2a_client(
        f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )

-    # Create and return A2A client
+    # Create A2A client
    a2a_client = _A2AClient(
        httpx_client=httpx_client,
        agent_card=agent_card,
    )

+    # Store agent_card on client for later retrieval (SDK doesn't expose it)
+    a2a_client._litellm_agent_card = agent_card  # type: ignore[attr-defined]
+
    verbose_logger.info(f"A2A client created for {base_url}")

    return a2a_client
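Putting the pieces above together, a hedged end-to-end sketch of the SDK path; the base_url, ids, and message contents are placeholders, and a reachable A2A agent is assumed:

```python
import asyncio

from a2a.types import MessageSendParams, SendMessageRequest

from litellm.a2a import asend_message, create_a2a_client


async def main():
    # Placeholder URL - point this at a running A2A agent
    a2a_client = await create_a2a_client(base_url="http://localhost:10001")

    request = SendMessageRequest(
        id="req-1",
        params=MessageSendParams(
            message={
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-1",
            }
        ),
    )

    # _get_a2a_model_info tags the call as model="a2a_agent/<agent name>",
    # so it shows up on the Logs page under provider "a2a_agent".
    response = await asend_message(a2a_client=a2a_client, request=request)
    print(response.model_dump(mode="json", exclude_none=True))


asyncio.run(main())
```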
@@ -934,6 +934,17 @@ def completion_cost( # noqa: PLR0915
        prompt_tokens = token_counter(model=model, text=prompt)
        completion_tokens = token_counter(model=model, text=completion)

+        # Handle A2A calls before model check - A2A doesn't require a model
+        if call_type in (
+            CallTypes.asend_message.value,
+            CallTypes.send_message.value,
+        ):
+            from litellm.a2a.cost_calculator import A2ACostCalculator
+
+            return A2ACostCalculator.calculate_a2a_cost(
+                litellm_logging_obj=litellm_logging_obj
+            )
+
        if model is None:
            raise ValueError(
                f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"
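What this routing implies for callers, as a sketch; the invocation below is illustrative (the model name and arguments are assumptions, not taken from the diff):

```python
from litellm import completion_cost

# With an A2A call type, cost is resolved before the model-pricing lookup,
# so an agent pseudo-model name works even though no pricing entry exists.
cost = completion_cost(
    completion_response=None,
    model="a2a_agent/my-agent",  # hypothetical agent name
    call_type="asend_message",
)
print(cost)  # 0.0 by default, until per-agent pricing is configurable
```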
BIN litellm/proxy/_experimental/out/assets/logos/a2a_agent.png (new file, 71 KiB; binary file not shown)
@@ -8,7 +8,7 @@ The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM
import json
from typing import Any, Optional

-from fastapi import APIRouter, Depends, HTTPException, Request
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse

from litellm._logging import verbose_proxy_logger
@@ -45,24 +45,6 @@ def _get_agent(agent_id: str):
    return agent


-async def _handle_send_message(
-    a2a_client: Any,
-    request_id: str,
-    params: dict,
-) -> JSONResponse:
-    """Handle message/send method."""
-    from a2a.types import MessageSendParams, SendMessageRequest
-
-    from litellm.a2a import asend_message
-
-    a2a_request = SendMessageRequest(
-        id=request_id,
-        params=MessageSendParams(**params),
-    )
-    response = await asend_message(a2a_client=a2a_client, request=a2a_request)
-    return JSONResponse(content=response.model_dump(mode="json", exclude_none=True))
-
-
async def _handle_stream_message(
    a2a_client: Any,
    request_id: str,
@@ -136,6 +118,7 @@ async def get_agent_card(
async def invoke_agent_a2a(
    agent_id: str,
    request: Request,
+    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
@@ -145,7 +128,13 @@ async def invoke_agent_a2a(
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response
    """
-    from litellm.a2a import create_a2a_client
+    from litellm.a2a import asend_message, create_a2a_client
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+    from litellm.proxy.proxy_server import (
+        general_settings,
+        proxy_config,
+        version,
+    )

    body = {}
    try:
@@ -165,18 +154,50 @@ async def invoke_agent_a2a(
    if agent is None:
        return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' not found", 404)

-    # Get backend URL
+    # Get backend URL and agent name
    agent_url = agent.agent_card_params.get("url")
+    agent_name = agent.agent_card_params.get("name", agent_id)
    if not agent_url:
        return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500)

    verbose_proxy_logger.info(f"Proxying A2A request to agent '{agent_id}' at {agent_url}")

-    # Create A2A client and dispatch to handler
+    # Set up data dict for litellm processing
+    body.update({
+        "model": f"a2a_agent/{agent_name}",
+        "custom_llm_provider": "a2a_agent",
+    })
+
+    # Add litellm data (user_api_key, user_id, team_id, etc.)
+    data = await add_litellm_data_to_request(
+        data=body,
+        request=request,
+        user_api_key_dict=user_api_key_dict,
+        proxy_config=proxy_config,
+        general_settings=general_settings,
+        version=version,
+    )
+
+    # Create A2A client
    a2a_client = await create_a2a_client(base_url=agent_url)

    if method == "message/send":
-        return await _handle_send_message(a2a_client, request_id, params)
+        from a2a.types import MessageSendParams, SendMessageRequest
+
+        a2a_request = SendMessageRequest(
+            id=request_id,
+            params=MessageSendParams(**params),
+        )
+
+        # Pass litellm data through kwargs for proper logging
+        response = await asend_message(
+            a2a_client=a2a_client,
+            request=a2a_request,
+            metadata=data.get("metadata", {}),
+            proxy_server_request=data.get("proxy_server_request"),
+        )
+        return JSONResponse(content=response.model_dump(mode="json", exclude_none=True))
+
    elif method == "message/stream":
        return await _handle_stream_message(a2a_client, request_id, params)
    else:
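Seen from a client, a hedged sketch of calling this proxy route; the proxy URL, agent id, and API key are placeholders for a locally running LiteLLM proxy with an A2A agent registered:

```python
import httpx

# Placeholders: proxy URL, agent id ("my-agent"), and API key are illustrative.
resp = httpx.post(
    "http://localhost:4000/a2a/my-agent",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "jsonrpc": "2.0",
        "id": "req-1",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-1",
            }
        },
    },
)
print(resp.json())  # JSON-RPC response; the request/response pair also lands on the Logs page
```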
@@ -345,6 +345,12 @@ class CallTypes(str, Enum):
    #########################################################
    call_mcp_tool = "call_mcp_tool"

+    #########################################################
+    # A2A Call Types
+    #########################################################
+    asend_message = "asend_message"
+    send_message = "send_message"
+

CallTypesLiteral = Literal[
    "embedding",
@@ -397,6 +403,8 @@ CallTypesLiteral = Literal[
    "vector_store_file_delete",
    "avector_store_file_delete",
    "call_mcp_tool",
+    "asend_message",
+    "send_message",
    "aresponses",
    "responses",
]
@@ -712,6 +720,8 @@ API_ROUTE_TO_CALL_TYPES = {
    ],
    # MCP (Model Context Protocol)
    "/mcp/call_tool": [CallTypes.call_mcp_tool],
+    # A2A (Agent-to-Agent)
+    "/a2a/{agent_id}": [CallTypes.asend_message, CallTypes.send_message],
    # Passthrough endpoints
    "/llm_passthrough": [
        CallTypes.llm_passthrough_route,
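A small illustrative check of the new route mapping, assuming API_ROUTE_TO_CALL_TYPES and CallTypes are importable from litellm.types.utils (the proxy's own router handles templated path matching; here the dict is read directly):

```python
from litellm.types.utils import API_ROUTE_TO_CALL_TYPES, CallTypes

# The proxy resolves /a2a/<agent_id> against the templated key below when
# attributing a request to a call type for the Logs page.
call_types = API_ROUTE_TO_CALL_TYPES["/a2a/{agent_id}"]
assert CallTypes.asend_message in call_types
assert CallTypes.send_message in call_types
```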
@@ -2981,6 +2991,7 @@ class LlmProviders(str, Enum):
    WANDB = "wandb"
    OVHCLOUD = "ovhcloud"
    LEMONADE = "lemonade"
+    A2A_AGENT = "a2a_agent"


# Create a set of all provider values for quick lookup
@@ -133,6 +133,35 @@ async def test_a2a_logging_payload():
    print(f"standard_logging_payload: {test_logger.standard_logging_payload}")
    print(f"logged kwargs: {json.dumps(test_logger.logged_kwargs, indent=4, default=str)}")

    # Verify logging was called
    assert test_logger.log_success_called is True
    assert test_logger.standard_logging_payload is not None

+    # Verify standard_logging_payload exists
+    slp = test_logger.standard_logging_payload
+    assert slp is not None
+
+    # Get values from standard logging payload
+    logged_model = slp.get("model") if isinstance(slp, dict) else getattr(slp, "model", None)
+    logged_provider = slp.get("custom_llm_provider") if isinstance(slp, dict) else getattr(slp, "custom_llm_provider", None)
+    call_type = slp.get("call_type") if isinstance(slp, dict) else getattr(slp, "call_type", None)
+    response_cost = slp.get("response_cost") if isinstance(slp, dict) else getattr(slp, "response_cost", None)
+
+    print(f"\n=== Standard Logging Payload Validation ===")
+    print(f"model: {logged_model}")
+    print(f"custom_llm_provider: {logged_provider}")
+    print(f"call_type: {call_type}")
+    print(f"response_cost: {response_cost}")
+
+    # Verify model and custom_llm_provider are set correctly
+    assert logged_model is not None, "model should be set"
+    assert "a2a_agent/" in logged_model, f"model should contain 'a2a_agent/', got: {logged_model}"
+    assert logged_provider == "a2a_agent", f"custom_llm_provider should be 'a2a_agent', got: {logged_provider}"
+
+    # Verify call_type is correct for A2A
+    assert call_type == "asend_message", f"call_type should be 'asend_message', got: {call_type}"
+
+    # Verify response_cost is set to 0.0 (not None, not an error)
+    # This confirms the A2A cost calculator is working
+    assert response_cost is not None, "response_cost should not be None"
+    assert response_cost == 0.0, f"response_cost should be 0.0 for A2A, got: {response_cost}"
tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py (new file, 114 lines)
@@ -0,0 +1,114 @@
"""
Mock tests for A2A endpoints.

Tests that invoke_agent_a2a properly integrates with add_litellm_data_to_request.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.mark.asyncio
async def test_invoke_agent_a2a_adds_litellm_data():
    """
    Test that invoke_agent_a2a calls add_litellm_data_to_request
    and the resulting data includes proxy_server_request.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # Track the data passed to add_litellm_data_to_request
    captured_data = {}

    async def mock_add_litellm_data(data, **kwargs):
        # Simulate what add_litellm_data_to_request does
        data["proxy_server_request"] = {
            "url": "http://localhost:4000/a2a/test-agent",
            "method": "POST",
            "headers": {},
            "body": dict(data),
        }
        captured_data.update(data)
        return data

    # Mock response from asend_message
    mock_response = MagicMock()
    mock_response.model_dump.return_value = {
        "jsonrpc": "2.0",
        "id": "test-id",
        "result": {"status": "success"},
    }

    # Mock agent
    mock_agent = MagicMock()
    mock_agent.agent_card_params = {
        "url": "http://backend-agent:10001",
        "name": "Test Agent",
    }

    # Mock request
    mock_request = MagicMock()
    mock_request.json = AsyncMock(return_value={
        "jsonrpc": "2.0",
        "id": "test-id",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-123",
            }
        },
    })

    mock_user_api_key_dict = UserAPIKeyAuth(
        api_key="sk-test-key",
        user_id="test-user",
        team_id="test-team",
    )

    # Patch at the source modules
    with patch(
        "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent",
        return_value=mock_agent,
    ), patch(
        "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request",
        side_effect=mock_add_litellm_data,
    ) as mock_add_data, patch(
        "litellm.a2a.create_a2a_client",
        new_callable=AsyncMock,
    ), patch(
        "litellm.a2a.asend_message",
        new_callable=AsyncMock,
        return_value=mock_response,
    ), patch(
        "litellm.proxy.proxy_server.general_settings",
        {},
    ), patch(
        "litellm.proxy.proxy_server.proxy_config",
        MagicMock(),
    ), patch(
        "litellm.proxy.proxy_server.version",
        "1.0.0",
    ):
        from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a

        mock_fastapi_response = MagicMock()

        result = await invoke_agent_a2a(
            agent_id="test-agent",
            request=mock_request,
            fastapi_response=mock_fastapi_response,
            user_api_key_dict=mock_user_api_key_dict,
        )

        # Verify add_litellm_data_to_request was called
        mock_add_data.assert_called_once()

        # Verify model and custom_llm_provider were set
        assert captured_data.get("model") == "a2a_agent/Test Agent"
        assert captured_data.get("custom_llm_provider") == "a2a_agent"

        # Verify proxy_server_request was added
        assert "proxy_server_request" in captured_data
        assert captured_data["proxy_server_request"]["method"] == "POST"
BIN ui/litellm-dashboard/public/assets/logos/a2a_agent.png (new file, 71 KiB; binary file not shown)
@@ -1,4 +1,6 @@
export enum Providers {
+  A2A_Agent = "A2A Agent",
  AIML = "AI/ML API",
  Bedrock = "Amazon Bedrock",
  Anthropic = "Anthropic",
@@ -43,6 +45,7 @@ export enum Providers {
}

export const provider_map: Record<string, string> = {
+  A2A_Agent: "a2a_agent",
  AIML: "aiml",
  OpenAI: "openai",
  OpenAI_Text: "text-completion-openai",
@@ -89,6 +92,7 @@ export const provider_map: Record<string, string> = {
const asset_logos_folder = "../ui/assets/logos/";

export const providerLogoMap: Record<string, string> = {
+  [Providers.A2A_Agent]: `${asset_logos_folder}a2a_agent.png`,
  [Providers.AIML]: `${asset_logos_folder}aiml_api.svg`,
  [Providers.Anthropic]: `${asset_logos_folder}anthropic.svg`,
  [Providers.AssemblyAI]: `${asset_logos_folder}assemblyai_small.png`,