[Feat] Agent Gateway - Allow tracking request / response in "Logs" Page (#17449)

* init litellm A2A client

* simpler a2a client interface

* test a2a

* move a2a invoking tests

* test fix

* ensure a2a send message is tracked in logs

* rename tags

* add streaming handling

* add a2a invocation

* add a2a invocation in cost calc

* test_a2a_logging_payload

* update invoke_agent_a2a

* test_invoke_agent_a2a_adds_litellm_data

* add A2A agent
Ishaan Jaff
2025-12-03 18:57:18 -08:00
committed by GitHub
parent 4370f6fb74
commit 585aee2ae4
11 changed files with 289 additions and 26 deletions


@@ -0,0 +1,36 @@
"""
Cost calculator for A2A (Agent-to-Agent) calls.
"""
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LitellmLoggingObject,
    )
else:
    LitellmLoggingObject = Any


class A2ACostCalculator:
    @staticmethod
    def calculate_a2a_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """
        Calculate the cost of an A2A send_message call.

        Default is 0.0. In the future, users can configure cost per agent call.
        """
        if litellm_logging_obj is None:
            return 0.0

        # Check if user set a custom response cost
        response_cost = litellm_logging_obj.model_call_details.get(
            "response_cost", None
        )
        if response_cost is not None:
            return response_cost

        # Default to 0.0 for A2A calls
        return 0.0
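
A quick usage sketch of the calculator above: a minimal, hedged example assuming a stub object that only carries the model_call_details dict the calculator reads (the stub class is hypothetical, not part of the diff):

```python
from litellm.a2a.cost_calculator import A2ACostCalculator


class _StubLogging:
    """Hypothetical stand-in: only model_call_details is read by the calculator."""

    def __init__(self, details: dict):
        self.model_call_details = details


# No logging object -> default cost of 0.0
assert A2ACostCalculator.calculate_a2a_cost(litellm_logging_obj=None) == 0.0

# A user-set response_cost is passed through unchanged
stub = _StubLogging({"response_cost": 0.25})
assert A2ACostCalculator.calculate_a2a_cost(litellm_logging_obj=stub) == 0.25
```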


@@ -38,6 +38,38 @@ except ImportError:
    pass


def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
    """
    Extract agent info and set model/custom_llm_provider for cost tracking.

    Sets model info on the litellm_logging_obj if available.
    Returns the agent name for logging.
    """
    agent_name = "unknown"
    # Try to get agent card from our stored attribute first, then fallback to SDK attribute
    agent_card = getattr(a2a_client, "_litellm_agent_card", None)
    if agent_card is None:
        agent_card = getattr(a2a_client, "agent_card", None)
    if agent_card is not None:
        agent_name = getattr(agent_card, "name", "unknown") or "unknown"

    # Build model string
    model = f"a2a_agent/{agent_name}"
    custom_llm_provider = "a2a_agent"

    # Set on litellm_logging_obj if available (for standard logging payload)
    litellm_logging_obj = kwargs.get("litellm_logging_obj")
    if litellm_logging_obj is not None:
        litellm_logging_obj.model = model
        litellm_logging_obj.custom_llm_provider = custom_llm_provider
        litellm_logging_obj.model_call_details["model"] = model
        litellm_logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider

    return agent_name


@client
async def asend_message(
    a2a_client: "A2AClientType",
@@ -80,7 +112,9 @@ async def asend_message(
        response = await asend_message(a2a_client=a2a_client, request=request)
        ```
    """
    verbose_logger.info(f"A2A send_message request_id={request.id}")
    agent_name = _get_a2a_model_info(a2a_client, kwargs)
    verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")

    a2a_response = await a2a_client.send_message(request)
@@ -207,12 +241,15 @@ async def create_a2a_client(
        f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )

    # Create and return A2A client
    # Create A2A client
    a2a_client = _A2AClient(
        httpx_client=httpx_client,
        agent_card=agent_card,
    )
    # Store agent_card on client for later retrieval (SDK doesn't expose it)
    a2a_client._litellm_agent_card = agent_card  # type: ignore[attr-defined]

    verbose_logger.info(f"A2A client created for {base_url}")
    return a2a_client
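
Putting create_a2a_client and asend_message together, a hedged end-to-end sketch: the agent URL and message payload are illustrative, and the request construction mirrors the SendMessageRequest usage in the proxy handler further down:

```python
import asyncio

from a2a.types import MessageSendParams, SendMessageRequest
from litellm.a2a import asend_message, create_a2a_client


async def main():
    # Illustrative URL - point this at a running A2A agent
    a2a_client = await create_a2a_client(base_url="http://localhost:10001")
    request = SendMessageRequest(
        id="req-1",
        params=MessageSendParams(
            message={
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-1",
            }
        ),
    )
    # Tracked in logs with model "a2a_agent/<agent name>" and provider "a2a_agent"
    response = await asend_message(a2a_client=a2a_client, request=request)
    print(response.model_dump(mode="json", exclude_none=True))


asyncio.run(main())
```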


@@ -934,6 +934,17 @@ def completion_cost( # noqa: PLR0915
        prompt_tokens = token_counter(model=model, text=prompt)
        completion_tokens = token_counter(model=model, text=completion)

    # Handle A2A calls before model check - A2A doesn't require a model
    if call_type in (
        CallTypes.asend_message.value,
        CallTypes.send_message.value,
    ):
        from litellm.a2a.cost_calculator import A2ACostCalculator

        return A2ACostCalculator.calculate_a2a_cost(
            litellm_logging_obj=litellm_logging_obj
        )

    if model is None:
        raise ValueError(
            f"Model is None and does not exist in passed completion_response. Passed completion_response={completion_response}, model={model}"

Binary file not shown (image added, 71 KiB).


@@ -8,7 +8,7 @@ The A2A SDK can point to LiteLLM's URL and invoke agents registered with LiteLLM
import json
from typing import Any, Optional

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse

from litellm._logging import verbose_proxy_logger
@@ -45,24 +45,6 @@ def _get_agent(agent_id: str):
    return agent


async def _handle_send_message(
    a2a_client: Any,
    request_id: str,
    params: dict,
) -> JSONResponse:
    """Handle message/send method."""
    from a2a.types import MessageSendParams, SendMessageRequest
    from litellm.a2a import asend_message

    a2a_request = SendMessageRequest(
        id=request_id,
        params=MessageSendParams(**params),
    )
    response = await asend_message(a2a_client=a2a_client, request=a2a_request)
    return JSONResponse(content=response.model_dump(mode="json", exclude_none=True))


async def _handle_stream_message(
    a2a_client: Any,
    request_id: str,
@@ -136,6 +118,7 @@ async def get_agent_card(
async def invoke_agent_a2a(
    agent_id: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
@@ -145,7 +128,13 @@ async def invoke_agent_a2a(
    - message/send: Send a message and get a response
    - message/stream: Send a message and stream the response
    """
    from litellm.a2a import create_a2a_client
    from litellm.a2a import asend_message, create_a2a_client
    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
    from litellm.proxy.proxy_server import (
        general_settings,
        proxy_config,
        version,
    )

    body = {}
    try:
@@ -165,18 +154,50 @@ async def invoke_agent_a2a(
    if agent is None:
        return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' not found", 404)

    # Get backend URL
    # Get backend URL and agent name
    agent_url = agent.agent_card_params.get("url")
    agent_name = agent.agent_card_params.get("name", agent_id)
    if not agent_url:
        return _jsonrpc_error(request_id, -32000, f"Agent '{agent_id}' has no URL configured", 500)

    verbose_proxy_logger.info(f"Proxying A2A request to agent '{agent_id}' at {agent_url}")

    # Create A2A client and dispatch to handler
    # Set up data dict for litellm processing
    body.update({
        "model": f"a2a_agent/{agent_name}",
        "custom_llm_provider": "a2a_agent",
    })

    # Add litellm data (user_api_key, user_id, team_id, etc.)
    data = await add_litellm_data_to_request(
        data=body,
        request=request,
        user_api_key_dict=user_api_key_dict,
        proxy_config=proxy_config,
        general_settings=general_settings,
        version=version,
    )

    # Create A2A client
    a2a_client = await create_a2a_client(base_url=agent_url)

    if method == "message/send":
        return await _handle_send_message(a2a_client, request_id, params)
        from a2a.types import MessageSendParams, SendMessageRequest

        a2a_request = SendMessageRequest(
            id=request_id,
            params=MessageSendParams(**params),
        )
        # Pass litellm data through kwargs for proper logging
        response = await asend_message(
            a2a_client=a2a_client,
            request=a2a_request,
            metadata=data.get("metadata", {}),
            proxy_server_request=data.get("proxy_server_request"),
        )
        return JSONResponse(content=response.model_dump(mode="json", exclude_none=True))
    elif method == "message/stream":
        return await _handle_stream_message(a2a_client, request_id, params)
    else:
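
For reference, a hedged sketch of calling this proxy route as a client would: the proxy URL, agent id, and API key are placeholders, and the JSON-RPC body mirrors the shape asserted in the tests below:

```python
import httpx

resp = httpx.post(
    "http://localhost:4000/a2a/test-agent",  # placeholder proxy URL and agent id
    headers={"Authorization": "Bearer sk-1234"},  # placeholder key
    json={
        "jsonrpc": "2.0",
        "id": "req-1",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-123",
            }
        },
    },
)
print(resp.json())  # JSON-RPC response; the call now shows up in the Logs page
```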


@@ -345,6 +345,12 @@ class CallTypes(str, Enum):
    #########################################################
    call_mcp_tool = "call_mcp_tool"

    #########################################################
    # A2A Call Types
    #########################################################
    asend_message = "asend_message"
    send_message = "send_message"


CallTypesLiteral = Literal[
    "embedding",
@@ -397,6 +403,8 @@ CallTypesLiteral = Literal[
    "vector_store_file_delete",
    "avector_store_file_delete",
    "call_mcp_tool",
    "asend_message",
    "send_message",
    "aresponses",
    "responses",
]
@@ -712,6 +720,8 @@ API_ROUTE_TO_CALL_TYPES = {
    ],
    # MCP (Model Context Protocol)
    "/mcp/call_tool": [CallTypes.call_mcp_tool],
    # A2A (Agent-to-Agent)
    "/a2a/{agent_id}": [CallTypes.asend_message, CallTypes.send_message],
    # Passthrough endpoints
    "/llm_passthrough": [
        CallTypes.llm_passthrough_route,
@@ -2981,6 +2991,7 @@ class LlmProviders(str, Enum):
    WANDB = "wandb"
    OVHCLOUD = "ovhcloud"
    LEMONADE = "lemonade"
    A2A_AGENT = "a2a_agent"

# Create a set of all provider values for quick lookup

@@ -133,6 +133,35 @@ async def test_a2a_logging_payload():
    print(f"standard_logging_payload: {test_logger.standard_logging_payload}")
    print(f"logged kwargs: {json.dumps(test_logger.logged_kwargs, indent=4, default=str)}")

    # Verify logging was called
    assert test_logger.log_success_called is True
    assert test_logger.standard_logging_payload is not None

    # Verify standard_logging_payload exists
    slp = test_logger.standard_logging_payload
    assert slp is not None

    # Get values from standard logging payload
    logged_model = slp.get("model") if isinstance(slp, dict) else getattr(slp, "model", None)
    logged_provider = slp.get("custom_llm_provider") if isinstance(slp, dict) else getattr(slp, "custom_llm_provider", None)
    call_type = slp.get("call_type") if isinstance(slp, dict) else getattr(slp, "call_type", None)
    response_cost = slp.get("response_cost") if isinstance(slp, dict) else getattr(slp, "response_cost", None)

    print(f"\n=== Standard Logging Payload Validation ===")
    print(f"model: {logged_model}")
    print(f"custom_llm_provider: {logged_provider}")
    print(f"call_type: {call_type}")
    print(f"response_cost: {response_cost}")

    # Verify model and custom_llm_provider are set correctly
    assert logged_model is not None, "model should be set"
    assert "a2a_agent/" in logged_model, f"model should contain 'a2a_agent/', got: {logged_model}"
    assert logged_provider == "a2a_agent", f"custom_llm_provider should be 'a2a_agent', got: {logged_provider}"

    # Verify call_type is correct for A2A
    assert call_type == "asend_message", f"call_type should be 'asend_message', got: {call_type}"

    # Verify response_cost is set to 0.0 (not None, not an error)
    # This confirms the A2A cost calculator is working
    assert response_cost is not None, "response_cost should not be None"
    assert response_cost == 0.0, f"response_cost should be 0.0 for A2A, got: {response_cost}"


@@ -0,0 +1,114 @@
"""
Mock tests for A2A endpoints.

Tests that invoke_agent_a2a properly integrates with add_litellm_data_to_request.
"""
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


@pytest.mark.asyncio
async def test_invoke_agent_a2a_adds_litellm_data():
    """
    Test that invoke_agent_a2a calls add_litellm_data_to_request
    and the resulting data includes proxy_server_request.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # Track the data passed to add_litellm_data_to_request
    captured_data = {}

    async def mock_add_litellm_data(data, **kwargs):
        # Simulate what add_litellm_data_to_request does
        data["proxy_server_request"] = {
            "url": "http://localhost:4000/a2a/test-agent",
            "method": "POST",
            "headers": {},
            "body": dict(data),
        }
        captured_data.update(data)
        return data

    # Mock response from asend_message
    mock_response = MagicMock()
    mock_response.model_dump.return_value = {
        "jsonrpc": "2.0",
        "id": "test-id",
        "result": {"status": "success"},
    }

    # Mock agent
    mock_agent = MagicMock()
    mock_agent.agent_card_params = {
        "url": "http://backend-agent:10001",
        "name": "Test Agent",
    }

    # Mock request
    mock_request = MagicMock()
    mock_request.json = AsyncMock(return_value={
        "jsonrpc": "2.0",
        "id": "test-id",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": "Hello"}],
                "messageId": "msg-123",
            }
        },
    })

    mock_user_api_key_dict = UserAPIKeyAuth(
        api_key="sk-test-key",
        user_id="test-user",
        team_id="test-team",
    )

    # Patch at the source modules
    with patch(
        "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent",
        return_value=mock_agent,
    ), patch(
        "litellm.proxy.litellm_pre_call_utils.add_litellm_data_to_request",
        side_effect=mock_add_litellm_data,
    ) as mock_add_data, patch(
        "litellm.a2a.create_a2a_client",
        new_callable=AsyncMock,
    ), patch(
        "litellm.a2a.asend_message",
        new_callable=AsyncMock,
        return_value=mock_response,
    ), patch(
        "litellm.proxy.proxy_server.general_settings",
        {},
    ), patch(
        "litellm.proxy.proxy_server.proxy_config",
        MagicMock(),
    ), patch(
        "litellm.proxy.proxy_server.version",
        "1.0.0",
    ):
        from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a

        mock_fastapi_response = MagicMock()
        result = await invoke_agent_a2a(
            agent_id="test-agent",
            request=mock_request,
            fastapi_response=mock_fastapi_response,
            user_api_key_dict=mock_user_api_key_dict,
        )

    # Verify add_litellm_data_to_request was called
    mock_add_data.assert_called_once()

    # Verify model and custom_llm_provider were set
    assert captured_data.get("model") == "a2a_agent/Test Agent"
    assert captured_data.get("custom_llm_provider") == "a2a_agent"

    # Verify proxy_server_request was added
    assert "proxy_server_request" in captured_data
    assert captured_data["proxy_server_request"]["method"] == "POST"

Binary file not shown (image added, 71 KiB).


@@ -1,4 +1,6 @@
export enum Providers {
  A2A_Agent = "A2A Agent",
  AIML = "AI/ML API",
  Bedrock = "Amazon Bedrock",
  Anthropic = "Anthropic",
@@ -43,6 +45,7 @@ export enum Providers {
}

export const provider_map: Record<string, string> = {
  A2A_Agent: "a2a_agent",
  AIML: "aiml",
  OpenAI: "openai",
  OpenAI_Text: "text-completion-openai",
@@ -89,6 +92,7 @@ export const provider_map: Record<string, string> = {

const asset_logos_folder = "../ui/assets/logos/";

export const providerLogoMap: Record<string, string> = {
  [Providers.A2A_Agent]: `${asset_logos_folder}a2a_agent.png`,
  [Providers.AIML]: `${asset_logos_folder}aiml_api.svg`,
  [Providers.Anthropic]: `${asset_logos_folder}anthropic.svg`,
  [Providers.AssemblyAI]: `${asset_logos_folder}assemblyai_small.png`,