From 8097fafc05f6367b0a34e46a3e25e86f0b353fc6 Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Fri, 14 Nov 2025 18:23:30 -0800 Subject: [PATCH] Agents - support agent registration + discovery (A2A spec) (#16615) * fix: initial commit adding types * refactor: refactor to include agent registry * feat(agents/): endpoints.py working endpoint for agent discovery * feat(agent_endpoints/endpoints.py): add permission management logic to agents endpoint * feat: public endpoint for showing publicly discoverable agents * feat: make /public/agent_hub discoverable * feat(agent_endpoints/endpoints.py): working create agent endpoint adds dynamic agent registration to the proxy * feat: working crud endpoints * feat: working multi-instance create/delete agents * feat(migration.sql): add migration for agents table --- .../20251114182247_agents_table/migration.sql | 20 + .../litellm_proxy_extras/schema.prisma | 15 +- .../index.html} | 0 .../proxy/_experimental/out/guardrails.html | 1 - .../out/{logs.html => logs/index.html} | 0 .../{model-hub.html => model-hub/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../proxy/_experimental/out/onboarding.html | 1 - .../index.html} | 0 .../out/{teams.html => teams/index.html} | 0 .../{test-key.html => test-key/index.html} | 0 .../out/{usage.html => usage/index.html} | 0 .../out/{users.html => users/index.html} | 0 .../index.html} | 0 litellm/proxy/_new_secret_config.yaml | 21 + litellm/proxy/_types.py | 85 +++-- .../proxy/agent_endpoints/agent_registry.py | 239 ++++++++++++ litellm/proxy/agent_endpoints/endpoints.py | 358 ++++++++++++++++++ litellm/proxy/proxy_server.py | 44 ++- .../public_endpoints/public_endpoints.py | 14 + litellm/proxy/schema.prisma | 13 + litellm/types/agents.py | 186 +++++++++ pyrightconfig.json | 2 +- schema.prisma | 13 + 25 files changed, 962 insertions(+), 50 deletions(-) create mode 100644 litellm-proxy-extras/litellm_proxy_extras/migrations/20251114182247_agents_table/migration.sql rename litellm/proxy/_experimental/out/{api-reference.html => api-reference/index.html} (100%) delete mode 100644 litellm/proxy/_experimental/out/guardrails.html rename litellm/proxy/_experimental/out/{logs.html => logs/index.html} (100%) rename litellm/proxy/_experimental/out/{model-hub.html => model-hub/index.html} (100%) rename litellm/proxy/_experimental/out/{model_hub_table.html => model_hub_table/index.html} (100%) rename litellm/proxy/_experimental/out/{models-and-endpoints.html => models-and-endpoints/index.html} (100%) delete mode 100644 litellm/proxy/_experimental/out/onboarding.html rename litellm/proxy/_experimental/out/{organizations.html => organizations/index.html} (100%) rename litellm/proxy/_experimental/out/{teams.html => teams/index.html} (100%) rename litellm/proxy/_experimental/out/{test-key.html => test-key/index.html} (100%) rename litellm/proxy/_experimental/out/{usage.html => usage/index.html} (100%) rename litellm/proxy/_experimental/out/{users.html => users/index.html} (100%) rename litellm/proxy/_experimental/out/{virtual-keys.html => virtual-keys/index.html} (100%) create mode 100644 litellm/proxy/agent_endpoints/agent_registry.py create mode 100644 litellm/proxy/agent_endpoints/endpoints.py create mode 100644 litellm/types/agents.py diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20251114182247_agents_table/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20251114182247_agents_table/migration.sql new file mode 100644 index 0000000000..ef9c8103b4 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20251114182247_agents_table/migration.sql @@ -0,0 +1,20 @@ +-- AlterTable +ALTER TABLE "LiteLLM_DailyTagSpend" ADD COLUMN "request_id" TEXT; + +-- CreateTable +CREATE TABLE "LiteLLM_AgentsTable" ( + "agent_id" TEXT NOT NULL, + "agent_name" TEXT NOT NULL, + "litellm_params" JSONB, + "agent_card_params" JSONB NOT NULL, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "created_by" TEXT NOT NULL, + "updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_by" TEXT NOT NULL, + + CONSTRAINT "LiteLLM_AgentsTable_pkey" PRIMARY KEY ("agent_id") +); + +-- CreateIndex +CREATE UNIQUE INDEX "LiteLLM_AgentsTable_agent_name_key" ON "LiteLLM_AgentsTable"("agent_name"); + diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 8890456112..d6b7cebbd1 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable { updated_by String } + +// Agents on proxy +model LiteLLM_AgentsTable { + agent_id String @id @default(uuid()) + agent_name String @unique + litellm_params Json? + agent_card_params Json + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String +} + model LiteLLM_OrganizationTable { organization_id String @id @default(uuid()) organization_alias String @@ -610,4 +623,4 @@ model LiteLLM_CacheConfig { cache_settings Json created_at DateTime @default(now()) updated_at DateTime @updatedAt -} +} \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/api-reference.html b/litellm/proxy/_experimental/out/api-reference/index.html similarity index 100% rename from litellm/proxy/_experimental/out/api-reference.html rename to litellm/proxy/_experimental/out/api-reference/index.html diff --git a/litellm/proxy/_experimental/out/guardrails.html b/litellm/proxy/_experimental/out/guardrails.html deleted file mode 100644 index a5f4b95a37..0000000000 --- a/litellm/proxy/_experimental/out/guardrails.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/logs.html b/litellm/proxy/_experimental/out/logs/index.html similarity index 100% rename from litellm/proxy/_experimental/out/logs.html rename to litellm/proxy/_experimental/out/logs/index.html diff --git a/litellm/proxy/_experimental/out/model-hub.html b/litellm/proxy/_experimental/out/model-hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model-hub.html rename to litellm/proxy/_experimental/out/model-hub/index.html diff --git a/litellm/proxy/_experimental/out/model_hub_table.html b/litellm/proxy/_experimental/out/model_hub_table/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub_table.html rename to litellm/proxy/_experimental/out/model_hub_table/index.html diff --git a/litellm/proxy/_experimental/out/models-and-endpoints.html b/litellm/proxy/_experimental/out/models-and-endpoints/index.html similarity index 100% rename from litellm/proxy/_experimental/out/models-and-endpoints.html rename to litellm/proxy/_experimental/out/models-and-endpoints/index.html diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html deleted file mode 100644 index 23f711cd6e..0000000000 --- a/litellm/proxy/_experimental/out/onboarding.html +++ /dev/null @@ -1 +0,0 @@ -LiteLLM Dashboard \ No newline at end of file diff --git a/litellm/proxy/_experimental/out/organizations.html b/litellm/proxy/_experimental/out/organizations/index.html similarity index 100% rename from litellm/proxy/_experimental/out/organizations.html rename to litellm/proxy/_experimental/out/organizations/index.html diff --git a/litellm/proxy/_experimental/out/teams.html b/litellm/proxy/_experimental/out/teams/index.html similarity index 100% rename from litellm/proxy/_experimental/out/teams.html rename to litellm/proxy/_experimental/out/teams/index.html diff --git a/litellm/proxy/_experimental/out/test-key.html b/litellm/proxy/_experimental/out/test-key/index.html similarity index 100% rename from litellm/proxy/_experimental/out/test-key.html rename to litellm/proxy/_experimental/out/test-key/index.html diff --git a/litellm/proxy/_experimental/out/usage.html b/litellm/proxy/_experimental/out/usage/index.html similarity index 100% rename from litellm/proxy/_experimental/out/usage.html rename to litellm/proxy/_experimental/out/usage/index.html diff --git a/litellm/proxy/_experimental/out/users.html b/litellm/proxy/_experimental/out/users/index.html similarity index 100% rename from litellm/proxy/_experimental/out/users.html rename to litellm/proxy/_experimental/out/users/index.html diff --git a/litellm/proxy/_experimental/out/virtual-keys.html b/litellm/proxy/_experimental/out/virtual-keys/index.html similarity index 100% rename from litellm/proxy/_experimental/out/virtual-keys.html rename to litellm/proxy/_experimental/out/virtual-keys/index.html diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 3a4cf84609..89a5e01b16 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -11,3 +11,24 @@ model_list: model: openai/gpt-4o-mini-transcribe api_key: os.environ/OPENAI_API_KEY +agent_list: + - agent_name: my_custom_agent + agent_card_params: + protocolVersion: '1.0' + name: 'Hello World Agent' + description: Just a hello world agent + url: http://localhost:9999/ + version: 1.0.0 + defaultInputModes: ['text'] + defaultOutputModes: ['text'] + capabilities: + streaming: true + skills: + - id: 'hello_world' + name: 'Returns hello world' + description: 'just returns hello world' + tags: ['hello world'] + examples: ['hi', 'hello world'] + supportsAuthenticatedExtendedCard: true + litellm_params: + make_public: true \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index a212eab076..35ec1226fd 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -509,6 +509,7 @@ class LiteLLMRoutes(enum.Enum): "/litellm/.well-known/litellm-ui-config", "/.well-known/litellm-ui-config", "/public/model_hub", + "/public/agent_hub", ] ) @@ -781,9 +782,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase): allowed_cache_controls: Optional[list] = [] config: Optional[dict] = {} permissions: Optional[dict] = {} - model_max_budget: Optional[ - dict - ] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} + model_max_budget: Optional[dict] = ( + {} + ) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {} model_config = ConfigDict(protected_namespaces=()) model_rpm_limit: Optional[dict] = None @@ -1237,12 +1238,12 @@ class NewCustomerRequest(BudgetNewRequest): blocked: bool = False # allow/disallow requests for this end-user budget_id: Optional[str] = None # give either a budget_id or max_budget spend: Optional[float] = None - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) @model_validator(mode="before") @classmethod @@ -1264,12 +1265,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase): blocked: bool = False # allow/disallow requests for this end-user max_budget: Optional[float] = None budget_id: Optional[str] = None # give either a budget_id or max_budget - allowed_model_region: Optional[ - AllowedModelRegion - ] = None # require all user requests to use models in this specific region - default_model: Optional[ - str - ] = None # if no equivalent model in allowed region - default all requests to this model + allowed_model_region: Optional[AllowedModelRegion] = ( + None # require all user requests to use models in this specific region + ) + default_model: Optional[str] = ( + None # if no equivalent model in allowed region - default all requests to this model + ) class DeleteCustomerRequest(LiteLLMPydanticObjectBase): @@ -1353,15 +1354,15 @@ class NewTeamRequest(TeamBase): ] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm model_tpm_limit: Optional[Dict[str, int]] = None - team_member_budget: Optional[ - float - ] = None # allow user to set a budget for all team members - team_member_rpm_limit: Optional[ - int - ] = None # allow user to set RPM limit for all team members - team_member_tpm_limit: Optional[ - int - ] = None # allow user to set TPM limit for all team members + team_member_budget: Optional[float] = ( + None # allow user to set a budget for all team members + ) + team_member_rpm_limit: Optional[int] = ( + None # allow user to set RPM limit for all team members + ) + team_member_tpm_limit: Optional[int] = ( + None # allow user to set TPM limit for all team members + ) team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m" allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None @@ -1445,9 +1446,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase): class AddTeamCallback(LiteLLMPydanticObjectBase): callback_name: str - callback_type: Optional[ - Literal["success", "failure", "success_and_failure"] - ] = "success_and_failure" + callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = ( + "success_and_failure" + ) callback_vars: Dict[str, str] @model_validator(mode="before") @@ -1732,9 +1733,9 @@ class ConfigList(LiteLLMPydanticObjectBase): stored_in_db: Optional[bool] field_default_value: Any premium_field: bool = False - nested_fields: Optional[ - List[FieldDetail] - ] = None # For nested dictionary or Pydantic fields + nested_fields: Optional[List[FieldDetail]] = ( + None # For nested dictionary or Pydantic fields + ) class UserHeaderMapping(LiteLLMPydanticObjectBase): @@ -2114,9 +2115,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase): budget_id: Optional[str] = None created_at: datetime updated_at: datetime - user: Optional[ - Any - ] = None # You might want to replace 'Any' with a more specific type if available + user: Optional[Any] = ( + None # You might want to replace 'Any' with a more specific type if available + ) litellm_budget_table: Optional[LiteLLM_BudgetTable] = None model_config = ConfigDict(protected_namespaces=()) @@ -3054,9 +3055,9 @@ class TeamModelDeleteRequest(BaseModel): # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str - max_budget_in_organization: Optional[ - float - ] = None # Users max budget within the organization + max_budget_in_organization: Optional[float] = ( + None # Users max budget within the organization + ) class OrganizationMemberDeleteRequest(MemberDeleteRequest): @@ -3269,9 +3270,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase): Maps provider names to their budget configs. """ - providers: Dict[ - str, ProviderBudgetResponseObject - ] = {} # Dictionary mapping provider names to their budget configurations + providers: Dict[str, ProviderBudgetResponseObject] = ( + {} + ) # Dictionary mapping provider names to their budget configurations class ProxyStateVariables(TypedDict): @@ -3405,9 +3406,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): enforce_rbac: bool = False roles_jwt_field: Optional[str] = None # v2 on role mappings role_mappings: Optional[List[RoleMapping]] = None - object_id_jwt_field: Optional[ - str - ] = None # can be either user / team, inferred from the role mapping + object_id_jwt_field: Optional[str] = ( + None # can be either user / team, inferred from the role mapping + ) scope_mappings: Optional[List[ScopeMapping]] = None enforce_scope_based_access: bool = False enforce_team_based_model_access: bool = False diff --git a/litellm/proxy/agent_endpoints/agent_registry.py b/litellm/proxy/agent_endpoints/agent_registry.py new file mode 100644 index 0000000000..345c31e07b --- /dev/null +++ b/litellm/proxy/agent_endpoints/agent_registry.py @@ -0,0 +1,239 @@ +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.proxy.utils import PrismaClient +from litellm.types.agents import AgentConfig + + +class AgentRegistry: + def __init__(self): + self.agent_list: List[AgentConfig] = [] + + def reset_agent_list(self): + self.agent_list = [] + + def register_agent(self, agent_config: AgentConfig): + self.agent_list.append(agent_config) + + def deregister_agent(self, agent_name: str): + self.agent_list = [ + agent for agent in self.agent_list if agent.get("agent_name") != agent_name + ] + + def get_agent_list(self, agent_names: Optional[List[str]] = None): + if agent_names is not None: + return [ + agent + for agent in self.agent_list + if agent.get("agent_name") in agent_names + ] + return self.agent_list + + def get_public_agent_list(self): + public_agent_list = [] + for agent in self.agent_list: + if agent.get("litellm_params", {}).get("make_public", False) is True: + public_agent_list.append(agent) + return public_agent_list + + def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None): + if agent_config is None: + return None + + for agent_config_item in agent_config: + if not isinstance(agent_config_item, dict): + raise ValueError("agent_config must be a list of dictionaries") + + agent_name = agent_config_item.get("agent_name") + agent_card_params = agent_config_item.get("agent_card_params") + if not all([agent_name, agent_card_params]): + continue + + self.register_agent(agent_config=agent_config_item) + + def load_agents_from_db_and_config( + self, + agent_config: Optional[List[AgentConfig]] = None, + db_agents: Optional[List[Dict[str, Any]]] = None, + ): + self.reset_agent_list() + + if agent_config: + for agent_config_item in agent_config: + if not isinstance(agent_config_item, dict): + raise ValueError("agent_config must be a list of dictionaries") + + self.register_agent(agent_config=agent_config_item) + + if db_agents: + for db_agent in db_agents: + if not isinstance(db_agent, dict): + raise ValueError("db_agents must be a list of dictionaries") + + self.register_agent(agent_config=AgentConfig(**db_agent)) + return self.agent_list + + ########################################################### + ########### DB management helpers for agents ########### + ############################################################ + async def add_agent_to_db( + self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str + ) -> Dict[str, Any]: + """ + Add an agent to the database + """ + try: + agent_name = agent.get("agent_name") + + # Serialize litellm_params + litellm_params_obj: Any = agent.get("litellm_params", {}) + if hasattr(litellm_params_obj, "model_dump"): + litellm_params_dict = litellm_params_obj.model_dump() + else: + litellm_params_dict = ( + dict(litellm_params_obj) if litellm_params_obj else {} + ) + litellm_params: str = safe_dumps(litellm_params_dict) + + # Serialize agent_card_params + agent_card_params_obj: Any = agent.get("agent_card_params", {}) + if hasattr(agent_card_params_obj, "model_dump"): + agent_card_params_dict = agent_card_params_obj.model_dump() + else: + agent_card_params_dict = ( + dict(agent_card_params_obj) if agent_card_params_obj else {} + ) + agent_card_params: str = safe_dumps(agent_card_params_dict) + + # Create agent in DB + created_agent = await prisma_client.db.litellm_agentstable.create( + data={ + "agent_name": agent_name, + "litellm_params": litellm_params, + "agent_card_params": agent_card_params, + "created_by": created_by, + "updated_by": created_by, + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + + return dict(created_agent) + except Exception as e: + raise Exception(f"Error adding agent to DB: {str(e)}") + + async def delete_agent_from_db( + self, agent_id: str, prisma_client: PrismaClient + ) -> Dict[str, Any]: + """ + Delete an agent from the database + """ + try: + deleted_agent = await prisma_client.db.litellm_agentstable.delete( + where={"agent_id": agent_id} + ) + return dict(deleted_agent) + except Exception as e: + raise Exception(f"Error deleting agent from DB: {str(e)}") + + async def update_agent_in_db( + self, + agent_id: str, + agent: AgentConfig, + prisma_client: PrismaClient, + updated_by: str, + ) -> Dict[str, Any]: + """ + Update an agent in the database + """ + try: + agent_name = agent.get("agent_name") + + # Serialize litellm_params + litellm_params_obj: Any = agent.get("litellm_params", {}) + if hasattr(litellm_params_obj, "model_dump"): + litellm_params_dict = litellm_params_obj.model_dump() + else: + litellm_params_dict = ( + dict(litellm_params_obj) if litellm_params_obj else {} + ) + litellm_params: str = safe_dumps(litellm_params_dict) + + # Serialize agent_card_params + agent_card_params_obj: Any = agent.get("agent_card_params", {}) + if hasattr(agent_card_params_obj, "model_dump"): + agent_card_params_dict = agent_card_params_obj.model_dump() + else: + agent_card_params_dict = ( + dict(agent_card_params_obj) if agent_card_params_obj else {} + ) + agent_card_params: str = safe_dumps(agent_card_params_dict) + + # Update agent in DB + updated_agent = await prisma_client.db.litellm_agentstable.update( + where={"agent_id": agent_id}, + data={ + "agent_name": agent_name, + "litellm_params": litellm_params, + "agent_card_params": agent_card_params, + "updated_by": updated_by, + "updated_at": datetime.now(timezone.utc), + }, + ) + + return dict(updated_agent) + except Exception as e: + raise Exception(f"Error updating agent in DB: {str(e)}") + + @staticmethod + async def get_all_agents_from_db( + prisma_client: PrismaClient, + ) -> List[Dict[str, Any]]: + """ + Get all agents from the database + """ + try: + agents_from_db = await prisma_client.db.litellm_agentstable.find_many( + order={"created_at": "desc"}, + ) + + agents: List[Dict[str, Any]] = [] + for agent in agents_from_db: + agents.append(dict(agent)) + + return agents + except Exception as e: + raise Exception(f"Error getting agents from DB: {str(e)}") + + def get_agent_by_id( + self, + agent_id: str, + ) -> Optional[Dict[str, Any]]: + """ + Get an agent by its ID from the database + """ + try: + for agent in self.agent_list: + if agent.get("agent_id") == agent_id: + return dict(agent) + + return None + except Exception as e: + raise Exception(f"Error getting agent from DB: {str(e)}") + + def get_agent_by_name(self, agent_name: str) -> Optional[Dict[str, Any]]: + """ + Get an agent by its name from the database + """ + try: + for agent in self.agent_list: + if agent.get("agent_name") == agent_name: + return dict(agent) + + return None + except Exception as e: + raise Exception(f"Error getting agent from DB: {str(e)}") + + +global_agent_registry = AgentRegistry() diff --git a/litellm/proxy/agent_endpoints/endpoints.py b/litellm/proxy/agent_endpoints/endpoints.py new file mode 100644 index 0000000000..87c3cb44c4 --- /dev/null +++ b/litellm/proxy/agent_endpoints/endpoints.py @@ -0,0 +1,358 @@ +""" +Agent endpoints for registering + discovering agents via LiteLLM. + +Follows the A2A Spec. + +1. Register an agent via POST `/v1/agents` +2. Discover agents via GET `/v1/agents` +3. Get specific agent via GET `/v1/agents/{agent_id}` +""" + +from typing import Any, List + +from fastapi import APIRouter, Depends, HTTPException, Request + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.types.agents import AgentConfig, AgentResponse + +router = APIRouter() + + +@router.get( + "/v1/agents", + tags=["[beta] Agents"], + dependencies=[Depends(user_api_key_auth)], + response_model=List[AgentConfig], +) +async def get_agents( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), # Used for auth +): + """ + Example usage: + ``` + curl -X GET "http://localhost:4000/v1/agents" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer your-key" \ + ``` + + Returns: List[AgentConfig] + + """ + from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry + + try: + if ( + user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN + or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value + ): + return global_agent_registry.get_agent_list() + key_agents = user_api_key_dict.metadata.get("agents") + _team_metadata = user_api_key_dict.team_metadata or {} + team_agents = _team_metadata.get("agents") + if key_agents is not None: + return global_agent_registry.get_agent_list(agent_names=key_agents) + if team_agents is not None: + return global_agent_registry.get_agent_list(agent_names=team_agents) + return [] + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format( + str(e) + ) + ) + raise HTTPException( + status_code=500, detail={"error": f"Internal server error: {str(e)}"} + ) + + +#### CRUD ENDPOINTS FOR AGENTS #### + +from litellm.proxy.agent_endpoints.agent_registry import ( + global_agent_registry as AGENT_REGISTRY, +) + + +@router.post( + "/v1/agents", + tags=["[beta] Agents"], + dependencies=[Depends(user_api_key_auth)], + response_model=AgentResponse, +) +async def create_agent( + request: AgentConfig, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Create a new agent + + Example Request: + ```bash + curl -X POST "http://localhost:4000/agents" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "agent": { + "agent_name": "my-custom-agent", + "agent_card_params": { + "protocolVersion": "1.0", + "name": "Hello World Agent", + "description": "Just a hello world agent", + "url": "http://localhost:9999/", + "version": "1.0.0", + "defaultInputModes": ["text"], + "defaultOutputModes": ["text"], + "capabilities": { + "streaming": true + }, + "skills": [ + { + "id": "hello_world", + "name": "Returns hello world", + "description": "just returns hello world", + "tags": ["hello world"], + "examples": ["hi", "hello world"] + } + ] + }, + "litellm_params": { + "make_public": true + } + } + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Prisma client not initialized") + + try: + # Get the user ID from the API key auth + created_by = user_api_key_dict.user_id or "unknown" + + # check for naming conflicts + existing_agent = AGENT_REGISTRY.get_agent_by_name( + agent_name=request.get("agent_name") # type: ignore + ) + if existing_agent is not None: + raise HTTPException( + status_code=400, + detail=f"Agent with name {request.get('agent_name')} already exists", + ) + + result = await AGENT_REGISTRY.add_agent_to_db( + agent=request, prisma_client=prisma_client, created_by=created_by + ) + + agent_name = result.get("agent_name", "Unknown") + agent_id = result.get("agent_id", "Unknown") + + # Also register in memory + try: + AGENT_REGISTRY.register_agent(agent_config=request) + verbose_proxy_logger.info( + f"Successfully registered agent '{agent_name}' (ID: {agent_id}) in memory" + ) + except Exception as reg_error: + verbose_proxy_logger.warning( + f"Failed to register agent '{agent_name}' (ID: {agent_id}) in memory: {reg_error}" + ) + + return AgentResponse(**result) + + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error adding agent to db: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get( + "/v1/agents/{agent_id}", + tags=["[beta] Agents"], + dependencies=[Depends(user_api_key_auth)], + response_model=AgentResponse, +) +async def get_agent_by_id(agent_id: str): + """ + Get a specific agent by ID + + Example Request: + ```bash + curl -X GET "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Prisma client not initialized") + + try: + agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id) + if agent is None: + agent = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": agent_id} + ) + if agent is not None: + agent = dict(agent) + + if agent is None: + raise HTTPException( + status_code=404, detail=f"Agent with ID {agent_id} not found" + ) + + return AgentResponse(**agent) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error getting agent from db: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.put( + "/v1/agents/{agent_id}", + tags=["[beta] Agents"], + dependencies=[Depends(user_api_key_auth)], + response_model=AgentResponse, +) +async def update_agent( + agent_id: str, + request: AgentConfig, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Update an existing agent + + Example Request: + ```bash + curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " \\ + -H "Content-Type: application/json" \\ + -d '{ + "agent": { + "agent_name": "updated-agent", + "agent_card_params": { + "protocolVersion": "1.0", + "name": "Updated Agent", + "description": "Updated description", + "url": "http://localhost:9999/", + "version": "1.1.0", + "defaultInputModes": ["text"], + "defaultOutputModes": ["text"], + "capabilities": { + "streaming": true + }, + "skills": [] + }, + "litellm_params": { + "make_public": false + } + } + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException( + status_code=500, detail=CommonProxyErrors.db_not_connected_error.value + ) + + try: + # Check if agent exists + existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": agent_id} + ) + if existing_agent is not None: + existing_agent = dict(existing_agent) + + if existing_agent is None: + raise HTTPException( + status_code=404, detail=f"Agent with ID {agent_id} not found" + ) + + # Get the user ID from the API key auth + updated_by = user_api_key_dict.user_id or "unknown" + + result = await AGENT_REGISTRY.update_agent_in_db( + agent_id=agent_id, + agent=request, + prisma_client=prisma_client, + updated_by=updated_by, + ) + + # deregister in memory + AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore + # register in memory + AGENT_REGISTRY.register_agent(agent_config=request) + + verbose_proxy_logger.info( + f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory" + ) + + return AgentResponse(**result) + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error updating agent: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete( + "/v1/agents/{agent_id}", + tags=["Agents"], + dependencies=[Depends(user_api_key_auth)], +) +async def delete_agent(agent_id: str): + """ + Delete an agent + + Example Request: + ```bash + curl -X DELETE "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\ + -H "Authorization: Bearer " + ``` + + Example Response: + ```json + { + "message": "Agent 123e4567-e89b-12d3-a456-426614174000 deleted successfully" + } + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail="Prisma client not initialized") + + try: + # Check if agent exists + existing_agent = await prisma_client.db.litellm_agentstable.find_unique( + where={"agent_id": agent_id} + ) + if existing_agent is not None: + existing_agent = dict[Any, Any](existing_agent) + + if existing_agent is None: + raise HTTPException( + status_code=404, detail=f"Agent with ID {agent_id} not found in DB." + ) + + await AGENT_REGISTRY.delete_agent_from_db( + agent_id=agent_id, prisma_client=prisma_client + ) + + AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore + + return {"message": f"Agent {agent_id} deleted successfully"} + except HTTPException: + raise + except Exception as e: + verbose_proxy_logger.exception(f"Error deleting agent: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 06964d1cf8..226b5c1179 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -40,6 +40,7 @@ from litellm.constants import ( LITELLM_SETTINGS_SAFE_DB_OVERRIDES, ) from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.proxy.common_utils.callback_utils import normalize_callback_names from litellm.types.utils import ( ModelResponse, ModelResponseStream, @@ -48,8 +49,6 @@ from litellm.types.utils import ( ) from litellm.utils import load_credentials_from_list -from litellm.proxy.common_utils.callback_utils import normalize_callback_names - if TYPE_CHECKING: from aiohttp import ClientSession from opentelemetry.trace import Span as _Span @@ -175,6 +174,8 @@ from litellm.proxy._experimental.mcp_server.tool_registry import ( global_mcp_tool_registry, ) from litellm.proxy._types import * +from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry +from litellm.proxy.agent_endpoints.endpoints import router as agent_endpoints_router from litellm.proxy.analytics_endpoints.analytics_endpoints import ( router as analytics_router, ) @@ -471,6 +472,8 @@ from fastapi.security import OAuth2PasswordBearer from fastapi.security.api_key import APIKeyHeader from fastapi.staticfiles import StaticFiles +from litellm.types.agents import AgentConfig + # import enterprise folder enterprise_router = APIRouter() try: @@ -1061,6 +1064,7 @@ callback_settings: dict = {} log_file = "api_log.json" worker_config = None master_key: Optional[str] = None +config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None shared_aiohttp_session: Optional["ClientSession"] = ( @@ -2585,6 +2589,11 @@ class ProxyConfig: if mcp_tools_config: global_mcp_tool_registry.load_tools_from_config(mcp_tools_config) + ## AGENTS + agent_config = config.get("agent_list", None) + if agent_config: + global_agent_registry.load_agents_from_config(agent_config) # type: ignore + mcp_servers_config = config.get("mcp_servers", None) if mcp_servers_config: from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( @@ -2821,7 +2830,9 @@ class ProxyConfig: for k, v in _litellm_params.items(): if isinstance(v, str): # decrypt value - returns original value if decryption fails or no key is set - _value = decrypt_value_helper(value=v, key=k, return_original_value=True) + _value = decrypt_value_helper( + value=v, key=k, return_original_value=True + ) _litellm_params[k] = _value _litellm_params = LiteLLM_Params(**_litellm_params) @@ -3414,6 +3425,9 @@ class ProxyConfig: if self._should_load_db_object(object_type="mcp"): await self._init_mcp_servers_in_db() + if self._should_load_db_object(object_type="agents"): + await self._init_agents_in_db(prisma_client=prisma_client) + if self._should_load_db_object(object_type="pass_through_endpoints"): await self._init_pass_through_endpoints_in_db() @@ -3682,6 +3696,25 @@ class ProxyConfig: ) ) + async def _init_agents_in_db(self, prisma_client: PrismaClient): + from litellm.proxy.agent_endpoints.agent_registry import ( + global_agent_registry as AGENT_REGISTRY, + ) + + try: + db_agents = await AGENT_REGISTRY.get_all_agents_from_db( + prisma_client=prisma_client + ) + AGENT_REGISTRY.load_agents_from_db_and_config( + db_agents=db_agents, agent_config=config_agents + ) + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.proxy_server.py::ProxyConfig:_init_agents_in_db - {}".format( + str(e) + ) + ) + async def _init_search_tools_in_db(self, prisma_client: PrismaClient): """ Initialize search tools from database into the router on startup. @@ -8963,7 +8996,9 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915 if isinstance( config["litellm_settings"]["success_callback"], list ) and isinstance(updated_litellm_settings["success_callback"], list): - updated_success_callbacks_normalized = normalize_callback_names(updated_litellm_settings["success_callback"]) + updated_success_callbacks_normalized = normalize_callback_names( + updated_litellm_settings["success_callback"] + ) combined_success_callback = ( config["litellm_settings"]["success_callback"] + updated_success_callbacks_normalized @@ -10097,6 +10132,7 @@ app.include_router(cache_settings_router) app.include_router(user_agent_analytics_router) app.include_router(enterprise_router) app.include_router(ui_discovery_endpoints_router) +app.include_router(agent_endpoints_router) ######################################################## # MCP Server ######################################################## diff --git a/litellm/proxy/public_endpoints/public_endpoints.py b/litellm/proxy/public_endpoints/public_endpoints.py index 8c1e6b74b3..0160021846 100644 --- a/litellm/proxy/public_endpoints/public_endpoints.py +++ b/litellm/proxy/public_endpoints/public_endpoints.py @@ -7,6 +7,7 @@ from litellm.proxy.public_endpoints.provider_create_metadata import ( get_provider_create_metadata, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.types.agents import AgentCard from litellm.types.proxy.management_endpoints.model_management_endpoints import ( ModelGroupInfoProxy, ) @@ -45,6 +46,19 @@ async def public_model_hub(): return model_groups +@router.get( + "/public/agent_hub", + tags=["[beta] Agents", "public"], + dependencies=[Depends(user_api_key_auth)], + response_model=List[AgentCard], +) +async def get_agents(): + from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry + + agents = global_agent_registry.get_public_agent_list() + return [agent.get("agent_card_params") for agent in agents] + + @router.get( "/public/model_hub/info", tags=["public", "model management"], diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 51e6ea9454..d6b7cebbd1 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable { updated_by String } + +// Agents on proxy +model LiteLLM_AgentsTable { + agent_id String @id @default(uuid()) + agent_name String @unique + litellm_params Json? + agent_card_params Json + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String +} + model LiteLLM_OrganizationTable { organization_id String @id @default(uuid()) organization_alias String diff --git a/litellm/types/agents.py b/litellm/types/agents.py new file mode 100644 index 0000000000..b9be640fd7 --- /dev/null +++ b/litellm/types/agents.py @@ -0,0 +1,186 @@ +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel +from typing_extensions import Required, TypedDict + + +# AgentProvider +class AgentProvider(TypedDict, total=False): + """Represents the service provider of an agent.""" + + organization: str # required + url: str # required + + +# AgentExtension +class AgentExtension(TypedDict, total=False): + """A declaration of a protocol extension supported by an Agent.""" + + uri: str # required + description: Optional[str] + required: Optional[bool] + params: Optional[Dict[str, Any]] + + +# AgentCapabilities +class AgentCapabilities(TypedDict, total=False): + """Defines optional capabilities supported by an agent.""" + + streaming: Optional[bool] + pushNotifications: Optional[bool] + stateTransitionHistory: Optional[bool] + extensions: Optional[List[AgentExtension]] + + +# SecurityScheme types +class SecuritySchemeBase(TypedDict, total=False): + """Base properties shared by all security scheme objects.""" + + description: Optional[str] + + +class APIKeySecurityScheme(SecuritySchemeBase): + """Defines a security scheme using an API key.""" + + type: Literal["apiKey"] + in_: Literal["query", "header", "cookie"] # using in_ to avoid Python keyword + name: str + + +class HTTPAuthSecurityScheme(SecuritySchemeBase): + """Defines a security scheme using HTTP authentication.""" + + type: Literal["http"] + scheme: str + bearerFormat: Optional[str] + + +class MutualTLSSecurityScheme(SecuritySchemeBase): + """Defines a security scheme using mTLS authentication.""" + + type: Literal["mutualTLS"] + + +class OAuthFlows(TypedDict, total=False): + """Defines the configuration for the supported OAuth 2.0 flows.""" + + authorizationCode: Optional[Dict[str, Any]] + clientCredentials: Optional[Dict[str, Any]] + implicit: Optional[Dict[str, Any]] + password: Optional[Dict[str, Any]] + + +class OAuth2SecurityScheme(SecuritySchemeBase): + """Defines a security scheme using OAuth 2.0.""" + + type: Literal["oauth2"] + flows: OAuthFlows + oauth2MetadataUrl: Optional[str] + + +class OpenIdConnectSecurityScheme(SecuritySchemeBase): + """Defines a security scheme using OpenID Connect.""" + + type: Literal["openIdConnect"] + openIdConnectUrl: str + + +# Union of all security schemes +SecurityScheme = Union[ + APIKeySecurityScheme, + HTTPAuthSecurityScheme, + OAuth2SecurityScheme, + OpenIdConnectSecurityScheme, + MutualTLSSecurityScheme, +] + + +# AgentSkill +class AgentSkill(TypedDict, total=False): + """Represents a distinct capability or function that an agent can perform.""" + + id: str # required + name: str # required + description: str # required + tags: List[str] # required + examples: Optional[List[str]] + inputModes: Optional[List[str]] + outputModes: Optional[List[str]] + security: Optional[List[Dict[str, List[str]]]] + + +# AgentInterface +class AgentInterface(TypedDict, total=False): + """Declares a combination of a target URL and a transport protocol.""" + + url: str # required + transport: str # required (TransportProtocol | string) + + +# AgentCardSignature +class AgentCardSignature(TypedDict, total=False): + """Represents a JWS signature of an AgentCard.""" + + protected: str # required + signature: str # required + header: Optional[Dict[str, Any]] + + +# AgentCard +class AgentCard(TypedDict, total=False): + """ + The AgentCard is a self-describing manifest for an agent. + It provides essential metadata including the agent's identity, capabilities, + skills, supported communication methods, and security requirements. + """ + + # Required fields + protocolVersion: str + name: str + description: str + url: str + version: str + capabilities: AgentCapabilities + defaultInputModes: List[str] + defaultOutputModes: List[str] + skills: List[AgentSkill] + + # Optional fields + preferredTransport: Optional[str] + additionalInterfaces: Optional[List[AgentInterface]] + iconUrl: Optional[str] + provider: Optional[AgentProvider] + documentationUrl: Optional[str] + securitySchemes: Optional[Dict[str, SecurityScheme]] + security: Optional[List[Dict[str, List[str]]]] + supportsAuthenticatedExtendedCard: Optional[bool] + signatures: Optional[List[AgentCardSignature]] + + +class AgentLitellmParams(TypedDict): + make_public: bool + + +class AgentConfig(TypedDict, total=False): + agent_name: Required[str] + agent_card_params: Required[AgentCard] + litellm_params: AgentLitellmParams + + +# Request/Response models for CRUD endpoints + + +class AgentResponse(BaseModel): + agent_id: str + agent_name: str + litellm_params: Optional[Dict[str, Any]] = None + agent_card_params: Dict[str, Any] + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + created_by: Optional[str] = None + updated_by: Optional[str] = None + + +class ListAgentsResponse(BaseModel): + agents: List[AgentResponse] diff --git a/pyrightconfig.json b/pyrightconfig.json index 9a43abda78..f930e44d30 100644 --- a/pyrightconfig.json +++ b/pyrightconfig.json @@ -1,6 +1,6 @@ { "ignore": [], - "exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"], + "exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py", "litellm/proxy/_types.py"], "reportMissingImports": false, "reportPrivateImportUsage": false } diff --git a/schema.prisma b/schema.prisma index 51e6ea9454..d6b7cebbd1 100644 --- a/schema.prisma +++ b/schema.prisma @@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable { updated_by String } + +// Agents on proxy +model LiteLLM_AgentsTable { + agent_id String @id @default(uuid()) + agent_name String @unique + litellm_params Json? + agent_card_params Json + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String +} + model LiteLLM_OrganizationTable { organization_id String @id @default(uuid()) organization_alias String