mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
Agents - support agent registration + discovery (A2A spec) (#16615)
* fix: initial commit adding types * refactor: refactor to include agent registry * feat(agents/): endpoints.py working endpoint for agent discovery * feat(agent_endpoints/endpoints.py): add permission management logic to agents endpoint * feat: public endpoint for showing publicly discoverable agents * feat: make /public/agent_hub discoverable * feat(agent_endpoints/endpoints.py): working create agent endpoint adds dynamic agent registration to the proxy * feat: working crud endpoints * feat: working multi-instance create/delete agents * feat(migration.sql): add migration for agents table
This commit is contained in:
@@ -0,0 +1,20 @@
|
|||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "LiteLLM_DailyTagSpend" ADD COLUMN "request_id" TEXT;
|
||||||
|
|
||||||
|
-- CreateTable
|
||||||
|
CREATE TABLE "LiteLLM_AgentsTable" (
|
||||||
|
"agent_id" TEXT NOT NULL,
|
||||||
|
"agent_name" TEXT NOT NULL,
|
||||||
|
"litellm_params" JSONB,
|
||||||
|
"agent_card_params" JSONB NOT NULL,
|
||||||
|
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"created_by" TEXT NOT NULL,
|
||||||
|
"updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
"updated_by" TEXT NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT "LiteLLM_AgentsTable_pkey" PRIMARY KEY ("agent_id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- CreateIndex
|
||||||
|
CREATE UNIQUE INDEX "LiteLLM_AgentsTable_agent_name_key" ON "LiteLLM_AgentsTable"("agent_name");
|
||||||
|
|
||||||
@@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable {
|
|||||||
updated_by String
|
updated_by String
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Agents on proxy
|
||||||
|
model LiteLLM_AgentsTable {
|
||||||
|
agent_id String @id @default(uuid())
|
||||||
|
agent_name String @unique
|
||||||
|
litellm_params Json?
|
||||||
|
agent_card_params Json
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
}
|
||||||
|
|
||||||
model LiteLLM_OrganizationTable {
|
model LiteLLM_OrganizationTable {
|
||||||
organization_id String @id @default(uuid())
|
organization_id String @id @default(uuid())
|
||||||
organization_alias String
|
organization_alias String
|
||||||
@@ -610,4 +623,4 @@ model LiteLLM_CacheConfig {
|
|||||||
cache_settings Json
|
cache_settings Json
|
||||||
created_at DateTime @default(now())
|
created_at DateTime @default(now())
|
||||||
updated_at DateTime @updatedAt
|
updated_at DateTime @updatedAt
|
||||||
}
|
}
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -11,3 +11,24 @@ model_list:
|
|||||||
model: openai/gpt-4o-mini-transcribe
|
model: openai/gpt-4o-mini-transcribe
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
|
||||||
|
agent_list:
|
||||||
|
- agent_name: my_custom_agent
|
||||||
|
agent_card_params:
|
||||||
|
protocolVersion: '1.0'
|
||||||
|
name: 'Hello World Agent'
|
||||||
|
description: Just a hello world agent
|
||||||
|
url: http://localhost:9999/
|
||||||
|
version: 1.0.0
|
||||||
|
defaultInputModes: ['text']
|
||||||
|
defaultOutputModes: ['text']
|
||||||
|
capabilities:
|
||||||
|
streaming: true
|
||||||
|
skills:
|
||||||
|
- id: 'hello_world'
|
||||||
|
name: 'Returns hello world'
|
||||||
|
description: 'just returns hello world'
|
||||||
|
tags: ['hello world']
|
||||||
|
examples: ['hi', 'hello world']
|
||||||
|
supportsAuthenticatedExtendedCard: true
|
||||||
|
litellm_params:
|
||||||
|
make_public: true
|
||||||
@@ -509,6 +509,7 @@ class LiteLLMRoutes(enum.Enum):
|
|||||||
"/litellm/.well-known/litellm-ui-config",
|
"/litellm/.well-known/litellm-ui-config",
|
||||||
"/.well-known/litellm-ui-config",
|
"/.well-known/litellm-ui-config",
|
||||||
"/public/model_hub",
|
"/public/model_hub",
|
||||||
|
"/public/agent_hub",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -781,9 +782,9 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
|
|||||||
allowed_cache_controls: Optional[list] = []
|
allowed_cache_controls: Optional[list] = []
|
||||||
config: Optional[dict] = {}
|
config: Optional[dict] = {}
|
||||||
permissions: Optional[dict] = {}
|
permissions: Optional[dict] = {}
|
||||||
model_max_budget: Optional[
|
model_max_budget: Optional[dict] = (
|
||||||
dict
|
{}
|
||||||
] = {} # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
) # {"gpt-4": 5.0, "gpt-3.5-turbo": 5.0}, defaults to {}
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
model_rpm_limit: Optional[dict] = None
|
model_rpm_limit: Optional[dict] = None
|
||||||
@@ -1237,12 +1238,12 @@ class NewCustomerRequest(BudgetNewRequest):
|
|||||||
blocked: bool = False # allow/disallow requests for this end-user
|
blocked: bool = False # allow/disallow requests for this end-user
|
||||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||||
spend: Optional[float] = None
|
spend: Optional[float] = None
|
||||||
allowed_model_region: Optional[
|
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||||
AllowedModelRegion
|
None # require all user requests to use models in this specific region
|
||||||
] = None # require all user requests to use models in this specific region
|
)
|
||||||
default_model: Optional[
|
default_model: Optional[str] = (
|
||||||
str
|
None # if no equivalent model in allowed region - default all requests to this model
|
||||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -1264,12 +1265,12 @@ class UpdateCustomerRequest(LiteLLMPydanticObjectBase):
|
|||||||
blocked: bool = False # allow/disallow requests for this end-user
|
blocked: bool = False # allow/disallow requests for this end-user
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
budget_id: Optional[str] = None # give either a budget_id or max_budget
|
||||||
allowed_model_region: Optional[
|
allowed_model_region: Optional[AllowedModelRegion] = (
|
||||||
AllowedModelRegion
|
None # require all user requests to use models in this specific region
|
||||||
] = None # require all user requests to use models in this specific region
|
)
|
||||||
default_model: Optional[
|
default_model: Optional[str] = (
|
||||||
str
|
None # if no equivalent model in allowed region - default all requests to this model
|
||||||
] = None # if no equivalent model in allowed region - default all requests to this model
|
)
|
||||||
|
|
||||||
|
|
||||||
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
class DeleteCustomerRequest(LiteLLMPydanticObjectBase):
|
||||||
@@ -1353,15 +1354,15 @@ class NewTeamRequest(TeamBase):
|
|||||||
] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm
|
] = None # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm
|
||||||
|
|
||||||
model_tpm_limit: Optional[Dict[str, int]] = None
|
model_tpm_limit: Optional[Dict[str, int]] = None
|
||||||
team_member_budget: Optional[
|
team_member_budget: Optional[float] = (
|
||||||
float
|
None # allow user to set a budget for all team members
|
||||||
] = None # allow user to set a budget for all team members
|
)
|
||||||
team_member_rpm_limit: Optional[
|
team_member_rpm_limit: Optional[int] = (
|
||||||
int
|
None # allow user to set RPM limit for all team members
|
||||||
] = None # allow user to set RPM limit for all team members
|
)
|
||||||
team_member_tpm_limit: Optional[
|
team_member_tpm_limit: Optional[int] = (
|
||||||
int
|
None # allow user to set TPM limit for all team members
|
||||||
] = None # allow user to set TPM limit for all team members
|
)
|
||||||
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
|
team_member_key_duration: Optional[str] = None # e.g. "1d", "1w", "1m"
|
||||||
allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None
|
allowed_vector_store_indexes: Optional[List[AllowedVectorStoreIndexItem]] = None
|
||||||
|
|
||||||
@@ -1445,9 +1446,9 @@ class BlockKeyRequest(LiteLLMPydanticObjectBase):
|
|||||||
|
|
||||||
class AddTeamCallback(LiteLLMPydanticObjectBase):
|
class AddTeamCallback(LiteLLMPydanticObjectBase):
|
||||||
callback_name: str
|
callback_name: str
|
||||||
callback_type: Optional[
|
callback_type: Optional[Literal["success", "failure", "success_and_failure"]] = (
|
||||||
Literal["success", "failure", "success_and_failure"]
|
"success_and_failure"
|
||||||
] = "success_and_failure"
|
)
|
||||||
callback_vars: Dict[str, str]
|
callback_vars: Dict[str, str]
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@@ -1732,9 +1733,9 @@ class ConfigList(LiteLLMPydanticObjectBase):
|
|||||||
stored_in_db: Optional[bool]
|
stored_in_db: Optional[bool]
|
||||||
field_default_value: Any
|
field_default_value: Any
|
||||||
premium_field: bool = False
|
premium_field: bool = False
|
||||||
nested_fields: Optional[
|
nested_fields: Optional[List[FieldDetail]] = (
|
||||||
List[FieldDetail]
|
None # For nested dictionary or Pydantic fields
|
||||||
] = None # For nested dictionary or Pydantic fields
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserHeaderMapping(LiteLLMPydanticObjectBase):
|
class UserHeaderMapping(LiteLLMPydanticObjectBase):
|
||||||
@@ -2114,9 +2115,9 @@ class LiteLLM_OrganizationMembershipTable(LiteLLMPydanticObjectBase):
|
|||||||
budget_id: Optional[str] = None
|
budget_id: Optional[str] = None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
user: Optional[
|
user: Optional[Any] = (
|
||||||
Any
|
None # You might want to replace 'Any' with a more specific type if available
|
||||||
] = None # You might want to replace 'Any' with a more specific type if available
|
)
|
||||||
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
@@ -3054,9 +3055,9 @@ class TeamModelDeleteRequest(BaseModel):
|
|||||||
# Organization Member Requests
|
# Organization Member Requests
|
||||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||||
organization_id: str
|
organization_id: str
|
||||||
max_budget_in_organization: Optional[
|
max_budget_in_organization: Optional[float] = (
|
||||||
float
|
None # Users max budget within the organization
|
||||||
] = None # Users max budget within the organization
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
|
class OrganizationMemberDeleteRequest(MemberDeleteRequest):
|
||||||
@@ -3269,9 +3270,9 @@ class ProviderBudgetResponse(LiteLLMPydanticObjectBase):
|
|||||||
Maps provider names to their budget configs.
|
Maps provider names to their budget configs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
providers: Dict[
|
providers: Dict[str, ProviderBudgetResponseObject] = (
|
||||||
str, ProviderBudgetResponseObject
|
{}
|
||||||
] = {} # Dictionary mapping provider names to their budget configurations
|
) # Dictionary mapping provider names to their budget configurations
|
||||||
|
|
||||||
|
|
||||||
class ProxyStateVariables(TypedDict):
|
class ProxyStateVariables(TypedDict):
|
||||||
@@ -3405,9 +3406,9 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
|||||||
enforce_rbac: bool = False
|
enforce_rbac: bool = False
|
||||||
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
roles_jwt_field: Optional[str] = None # v2 on role mappings
|
||||||
role_mappings: Optional[List[RoleMapping]] = None
|
role_mappings: Optional[List[RoleMapping]] = None
|
||||||
object_id_jwt_field: Optional[
|
object_id_jwt_field: Optional[str] = (
|
||||||
str
|
None # can be either user / team, inferred from the role mapping
|
||||||
] = None # can be either user / team, inferred from the role mapping
|
)
|
||||||
scope_mappings: Optional[List[ScopeMapping]] = None
|
scope_mappings: Optional[List[ScopeMapping]] = None
|
||||||
enforce_scope_based_access: bool = False
|
enforce_scope_based_access: bool = False
|
||||||
enforce_team_based_model_access: bool = False
|
enforce_team_based_model_access: bool = False
|
||||||
|
|||||||
239
litellm/proxy/agent_endpoints/agent_registry.py
Normal file
239
litellm/proxy/agent_endpoints/agent_registry.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
from litellm.types.agents import AgentConfig
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
def __init__(self):
|
||||||
|
self.agent_list: List[AgentConfig] = []
|
||||||
|
|
||||||
|
def reset_agent_list(self):
|
||||||
|
self.agent_list = []
|
||||||
|
|
||||||
|
def register_agent(self, agent_config: AgentConfig):
|
||||||
|
self.agent_list.append(agent_config)
|
||||||
|
|
||||||
|
def deregister_agent(self, agent_name: str):
|
||||||
|
self.agent_list = [
|
||||||
|
agent for agent in self.agent_list if agent.get("agent_name") != agent_name
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_agent_list(self, agent_names: Optional[List[str]] = None):
|
||||||
|
if agent_names is not None:
|
||||||
|
return [
|
||||||
|
agent
|
||||||
|
for agent in self.agent_list
|
||||||
|
if agent.get("agent_name") in agent_names
|
||||||
|
]
|
||||||
|
return self.agent_list
|
||||||
|
|
||||||
|
def get_public_agent_list(self):
|
||||||
|
public_agent_list = []
|
||||||
|
for agent in self.agent_list:
|
||||||
|
if agent.get("litellm_params", {}).get("make_public", False) is True:
|
||||||
|
public_agent_list.append(agent)
|
||||||
|
return public_agent_list
|
||||||
|
|
||||||
|
def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
|
||||||
|
if agent_config is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for agent_config_item in agent_config:
|
||||||
|
if not isinstance(agent_config_item, dict):
|
||||||
|
raise ValueError("agent_config must be a list of dictionaries")
|
||||||
|
|
||||||
|
agent_name = agent_config_item.get("agent_name")
|
||||||
|
agent_card_params = agent_config_item.get("agent_card_params")
|
||||||
|
if not all([agent_name, agent_card_params]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
self.register_agent(agent_config=agent_config_item)
|
||||||
|
|
||||||
|
def load_agents_from_db_and_config(
|
||||||
|
self,
|
||||||
|
agent_config: Optional[List[AgentConfig]] = None,
|
||||||
|
db_agents: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
):
|
||||||
|
self.reset_agent_list()
|
||||||
|
|
||||||
|
if agent_config:
|
||||||
|
for agent_config_item in agent_config:
|
||||||
|
if not isinstance(agent_config_item, dict):
|
||||||
|
raise ValueError("agent_config must be a list of dictionaries")
|
||||||
|
|
||||||
|
self.register_agent(agent_config=agent_config_item)
|
||||||
|
|
||||||
|
if db_agents:
|
||||||
|
for db_agent in db_agents:
|
||||||
|
if not isinstance(db_agent, dict):
|
||||||
|
raise ValueError("db_agents must be a list of dictionaries")
|
||||||
|
|
||||||
|
self.register_agent(agent_config=AgentConfig(**db_agent))
|
||||||
|
return self.agent_list
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
########### DB management helpers for agents ###########
|
||||||
|
############################################################
|
||||||
|
async def add_agent_to_db(
|
||||||
|
self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Add an agent to the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent_name = agent.get("agent_name")
|
||||||
|
|
||||||
|
# Serialize litellm_params
|
||||||
|
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||||||
|
if hasattr(litellm_params_obj, "model_dump"):
|
||||||
|
litellm_params_dict = litellm_params_obj.model_dump()
|
||||||
|
else:
|
||||||
|
litellm_params_dict = (
|
||||||
|
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||||
|
)
|
||||||
|
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||||
|
|
||||||
|
# Serialize agent_card_params
|
||||||
|
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||||||
|
if hasattr(agent_card_params_obj, "model_dump"):
|
||||||
|
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||||||
|
else:
|
||||||
|
agent_card_params_dict = (
|
||||||
|
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||||||
|
)
|
||||||
|
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||||
|
|
||||||
|
# Create agent in DB
|
||||||
|
created_agent = await prisma_client.db.litellm_agentstable.create(
|
||||||
|
data={
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"litellm_params": litellm_params,
|
||||||
|
"agent_card_params": agent_card_params,
|
||||||
|
"created_by": created_by,
|
||||||
|
"updated_by": created_by,
|
||||||
|
"created_at": datetime.now(timezone.utc),
|
||||||
|
"updated_at": datetime.now(timezone.utc),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return dict(created_agent)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error adding agent to DB: {str(e)}")
|
||||||
|
|
||||||
|
async def delete_agent_from_db(
|
||||||
|
self, agent_id: str, prisma_client: PrismaClient
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Delete an agent from the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
deleted_agent = await prisma_client.db.litellm_agentstable.delete(
|
||||||
|
where={"agent_id": agent_id}
|
||||||
|
)
|
||||||
|
return dict(deleted_agent)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error deleting agent from DB: {str(e)}")
|
||||||
|
|
||||||
|
async def update_agent_in_db(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
agent: AgentConfig,
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
updated_by: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Update an agent in the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agent_name = agent.get("agent_name")
|
||||||
|
|
||||||
|
# Serialize litellm_params
|
||||||
|
litellm_params_obj: Any = agent.get("litellm_params", {})
|
||||||
|
if hasattr(litellm_params_obj, "model_dump"):
|
||||||
|
litellm_params_dict = litellm_params_obj.model_dump()
|
||||||
|
else:
|
||||||
|
litellm_params_dict = (
|
||||||
|
dict(litellm_params_obj) if litellm_params_obj else {}
|
||||||
|
)
|
||||||
|
litellm_params: str = safe_dumps(litellm_params_dict)
|
||||||
|
|
||||||
|
# Serialize agent_card_params
|
||||||
|
agent_card_params_obj: Any = agent.get("agent_card_params", {})
|
||||||
|
if hasattr(agent_card_params_obj, "model_dump"):
|
||||||
|
agent_card_params_dict = agent_card_params_obj.model_dump()
|
||||||
|
else:
|
||||||
|
agent_card_params_dict = (
|
||||||
|
dict(agent_card_params_obj) if agent_card_params_obj else {}
|
||||||
|
)
|
||||||
|
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||||
|
|
||||||
|
# Update agent in DB
|
||||||
|
updated_agent = await prisma_client.db.litellm_agentstable.update(
|
||||||
|
where={"agent_id": agent_id},
|
||||||
|
data={
|
||||||
|
"agent_name": agent_name,
|
||||||
|
"litellm_params": litellm_params,
|
||||||
|
"agent_card_params": agent_card_params,
|
||||||
|
"updated_by": updated_by,
|
||||||
|
"updated_at": datetime.now(timezone.utc),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return dict(updated_agent)
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error updating agent in DB: {str(e)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_all_agents_from_db(
|
||||||
|
prisma_client: PrismaClient,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get all agents from the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
|
||||||
|
order={"created_at": "desc"},
|
||||||
|
)
|
||||||
|
|
||||||
|
agents: List[Dict[str, Any]] = []
|
||||||
|
for agent in agents_from_db:
|
||||||
|
agents.append(dict(agent))
|
||||||
|
|
||||||
|
return agents
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error getting agents from DB: {str(e)}")
|
||||||
|
|
||||||
|
def get_agent_by_id(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get an agent by its ID from the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for agent in self.agent_list:
|
||||||
|
if agent.get("agent_id") == agent_id:
|
||||||
|
return dict(agent)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||||
|
|
||||||
|
def get_agent_by_name(self, agent_name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Get an agent by its name from the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
for agent in self.agent_list:
|
||||||
|
if agent.get("agent_name") == agent_name:
|
||||||
|
return dict(agent)
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Error getting agent from DB: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
global_agent_registry = AgentRegistry()
|
||||||
358
litellm/proxy/agent_endpoints/endpoints.py
Normal file
358
litellm/proxy/agent_endpoints/endpoints.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
"""
|
||||||
|
Agent endpoints for registering + discovering agents via LiteLLM.
|
||||||
|
|
||||||
|
Follows the A2A Spec.
|
||||||
|
|
||||||
|
1. Register an agent via POST `/v1/agents`
|
||||||
|
2. Discover agents via GET `/v1/agents`
|
||||||
|
3. Get specific agent via GET `/v1/agents/{agent_id}`
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||||
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.types.agents import AgentConfig, AgentResponse
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/v1/agents",
|
||||||
|
tags=["[beta] Agents"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=List[AgentConfig],
|
||||||
|
)
|
||||||
|
async def get_agents(
|
||||||
|
request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), # Used for auth
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Example usage:
|
||||||
|
```
|
||||||
|
curl -X GET "http://localhost:4000/v1/agents" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-H "Authorization: Bearer your-key" \
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns: List[AgentConfig]
|
||||||
|
|
||||||
|
"""
|
||||||
|
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||||
|
|
||||||
|
try:
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
|
||||||
|
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
):
|
||||||
|
return global_agent_registry.get_agent_list()
|
||||||
|
key_agents = user_api_key_dict.metadata.get("agents")
|
||||||
|
_team_metadata = user_api_key_dict.team_metadata or {}
|
||||||
|
team_agents = _team_metadata.get("agents")
|
||||||
|
if key_agents is not None:
|
||||||
|
return global_agent_registry.get_agent_list(agent_names=key_agents)
|
||||||
|
if team_agents is not None:
|
||||||
|
return global_agent_registry.get_agent_list(agent_names=team_agents)
|
||||||
|
return []
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"litellm.proxy.anthropic_endpoints.count_tokens(): Exception occurred - {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail={"error": f"Internal server error: {str(e)}"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
#### CRUD ENDPOINTS FOR AGENTS ####
|
||||||
|
|
||||||
|
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||||
|
global_agent_registry as AGENT_REGISTRY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/v1/agents",
|
||||||
|
tags=["[beta] Agents"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=AgentResponse,
|
||||||
|
)
|
||||||
|
async def create_agent(
|
||||||
|
request: AgentConfig,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new agent
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```bash
|
||||||
|
curl -X POST "http://localhost:4000/agents" \\
|
||||||
|
-H "Authorization: Bearer <your_api_key>" \\
|
||||||
|
-H "Content-Type: application/json" \\
|
||||||
|
-d '{
|
||||||
|
"agent": {
|
||||||
|
"agent_name": "my-custom-agent",
|
||||||
|
"agent_card_params": {
|
||||||
|
"protocolVersion": "1.0",
|
||||||
|
"name": "Hello World Agent",
|
||||||
|
"description": "Just a hello world agent",
|
||||||
|
"url": "http://localhost:9999/",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"defaultInputModes": ["text"],
|
||||||
|
"defaultOutputModes": ["text"],
|
||||||
|
"capabilities": {
|
||||||
|
"streaming": true
|
||||||
|
},
|
||||||
|
"skills": [
|
||||||
|
{
|
||||||
|
"id": "hello_world",
|
||||||
|
"name": "Returns hello world",
|
||||||
|
"description": "just returns hello world",
|
||||||
|
"tags": ["hello world"],
|
||||||
|
"examples": ["hi", "hello world"]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"litellm_params": {
|
||||||
|
"make_public": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get the user ID from the API key auth
|
||||||
|
created_by = user_api_key_dict.user_id or "unknown"
|
||||||
|
|
||||||
|
# check for naming conflicts
|
||||||
|
existing_agent = AGENT_REGISTRY.get_agent_by_name(
|
||||||
|
agent_name=request.get("agent_name") # type: ignore
|
||||||
|
)
|
||||||
|
if existing_agent is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=f"Agent with name {request.get('agent_name')} already exists",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await AGENT_REGISTRY.add_agent_to_db(
|
||||||
|
agent=request, prisma_client=prisma_client, created_by=created_by
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_name = result.get("agent_name", "Unknown")
|
||||||
|
agent_id = result.get("agent_id", "Unknown")
|
||||||
|
|
||||||
|
# Also register in memory
|
||||||
|
try:
|
||||||
|
AGENT_REGISTRY.register_agent(agent_config=request)
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"Successfully registered agent '{agent_name}' (ID: {agent_id}) in memory"
|
||||||
|
)
|
||||||
|
except Exception as reg_error:
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
f"Failed to register agent '{agent_name}' (ID: {agent_id}) in memory: {reg_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentResponse(**result)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error adding agent to db: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/v1/agents/{agent_id}",
|
||||||
|
tags=["[beta] Agents"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=AgentResponse,
|
||||||
|
)
|
||||||
|
async def get_agent_by_id(agent_id: str):
|
||||||
|
"""
|
||||||
|
Get a specific agent by ID
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```bash
|
||||||
|
curl -X GET "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||||
|
-H "Authorization: Bearer <your_api_key>"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||||
|
if agent is None:
|
||||||
|
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||||
|
where={"agent_id": agent_id}
|
||||||
|
)
|
||||||
|
if agent is not None:
|
||||||
|
agent = dict(agent)
|
||||||
|
|
||||||
|
if agent is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentResponse(**agent)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error getting agent from db: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.put(
|
||||||
|
"/v1/agents/{agent_id}",
|
||||||
|
tags=["[beta] Agents"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=AgentResponse,
|
||||||
|
)
|
||||||
|
async def update_agent(
|
||||||
|
agent_id: str,
|
||||||
|
request: AgentConfig,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update an existing agent
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```bash
|
||||||
|
curl -X PUT "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||||
|
-H "Authorization: Bearer <your_api_key>" \\
|
||||||
|
-H "Content-Type: application/json" \\
|
||||||
|
-d '{
|
||||||
|
"agent": {
|
||||||
|
"agent_name": "updated-agent",
|
||||||
|
"agent_card_params": {
|
||||||
|
"protocolVersion": "1.0",
|
||||||
|
"name": "Updated Agent",
|
||||||
|
"description": "Updated description",
|
||||||
|
"url": "http://localhost:9999/",
|
||||||
|
"version": "1.1.0",
|
||||||
|
"defaultInputModes": ["text"],
|
||||||
|
"defaultOutputModes": ["text"],
|
||||||
|
"capabilities": {
|
||||||
|
"streaming": true
|
||||||
|
},
|
||||||
|
"skills": []
|
||||||
|
},
|
||||||
|
"litellm_params": {
|
||||||
|
"make_public": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if agent exists
|
||||||
|
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||||
|
where={"agent_id": agent_id}
|
||||||
|
)
|
||||||
|
if existing_agent is not None:
|
||||||
|
existing_agent = dict(existing_agent)
|
||||||
|
|
||||||
|
if existing_agent is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Agent with ID {agent_id} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the user ID from the API key auth
|
||||||
|
updated_by = user_api_key_dict.user_id or "unknown"
|
||||||
|
|
||||||
|
result = await AGENT_REGISTRY.update_agent_in_db(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent=request,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
updated_by=updated_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
# deregister in memory
|
||||||
|
AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore
|
||||||
|
# register in memory
|
||||||
|
AGENT_REGISTRY.register_agent(agent_config=request)
|
||||||
|
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
f"Successfully updated agent '{existing_agent.get('agent_name')}' (ID: {agent_id}) in memory"
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentResponse(**result)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error updating agent: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete(
|
||||||
|
"/v1/agents/{agent_id}",
|
||||||
|
tags=["Agents"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
async def delete_agent(agent_id: str):
|
||||||
|
"""
|
||||||
|
Delete an agent
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```bash
|
||||||
|
curl -X DELETE "http://localhost:4000/agents/123e4567-e89b-12d3-a456-426614174000" \\
|
||||||
|
-H "Authorization: Bearer <your_api_key>"
|
||||||
|
```
|
||||||
|
|
||||||
|
Example Response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"message": "Agent 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if agent exists
|
||||||
|
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||||
|
where={"agent_id": agent_id}
|
||||||
|
)
|
||||||
|
if existing_agent is not None:
|
||||||
|
existing_agent = dict[Any, Any](existing_agent)
|
||||||
|
|
||||||
|
if existing_agent is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404, detail=f"Agent with ID {agent_id} not found in DB."
|
||||||
|
)
|
||||||
|
|
||||||
|
await AGENT_REGISTRY.delete_agent_from_db(
|
||||||
|
agent_id=agent_id, prisma_client=prisma_client
|
||||||
|
)
|
||||||
|
|
||||||
|
AGENT_REGISTRY.deregister_agent(agent_name=existing_agent.get("agent_name")) # type: ignore
|
||||||
|
|
||||||
|
return {"message": f"Agent {agent_id} deleted successfully"}
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(f"Error deleting agent: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -40,6 +40,7 @@ from litellm.constants import (
|
|||||||
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
|
LITELLM_SETTINGS_SAFE_DB_OVERRIDES,
|
||||||
)
|
)
|
||||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||||
|
from litellm.proxy.common_utils.callback_utils import normalize_callback_names
|
||||||
from litellm.types.utils import (
|
from litellm.types.utils import (
|
||||||
ModelResponse,
|
ModelResponse,
|
||||||
ModelResponseStream,
|
ModelResponseStream,
|
||||||
@@ -48,8 +49,6 @@ from litellm.types.utils import (
|
|||||||
)
|
)
|
||||||
from litellm.utils import load_credentials_from_list
|
from litellm.utils import load_credentials_from_list
|
||||||
|
|
||||||
from litellm.proxy.common_utils.callback_utils import normalize_callback_names
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from opentelemetry.trace import Span as _Span
|
from opentelemetry.trace import Span as _Span
|
||||||
@@ -175,6 +174,8 @@ from litellm.proxy._experimental.mcp_server.tool_registry import (
|
|||||||
global_mcp_tool_registry,
|
global_mcp_tool_registry,
|
||||||
)
|
)
|
||||||
from litellm.proxy._types import *
|
from litellm.proxy._types import *
|
||||||
|
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||||
|
from litellm.proxy.agent_endpoints.endpoints import router as agent_endpoints_router
|
||||||
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||||
router as analytics_router,
|
router as analytics_router,
|
||||||
)
|
)
|
||||||
@@ -471,6 +472,8 @@ from fastapi.security import OAuth2PasswordBearer
|
|||||||
from fastapi.security.api_key import APIKeyHeader
|
from fastapi.security.api_key import APIKeyHeader
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
|
from litellm.types.agents import AgentConfig
|
||||||
|
|
||||||
# import enterprise folder
|
# import enterprise folder
|
||||||
enterprise_router = APIRouter()
|
enterprise_router = APIRouter()
|
||||||
try:
|
try:
|
||||||
@@ -1061,6 +1064,7 @@ callback_settings: dict = {}
|
|||||||
log_file = "api_log.json"
|
log_file = "api_log.json"
|
||||||
worker_config = None
|
worker_config = None
|
||||||
master_key: Optional[str] = None
|
master_key: Optional[str] = None
|
||||||
|
config_agents: Optional[List[AgentConfig]] = None
|
||||||
otel_logging = False
|
otel_logging = False
|
||||||
prisma_client: Optional[PrismaClient] = None
|
prisma_client: Optional[PrismaClient] = None
|
||||||
shared_aiohttp_session: Optional["ClientSession"] = (
|
shared_aiohttp_session: Optional["ClientSession"] = (
|
||||||
@@ -2585,6 +2589,11 @@ class ProxyConfig:
|
|||||||
if mcp_tools_config:
|
if mcp_tools_config:
|
||||||
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
global_mcp_tool_registry.load_tools_from_config(mcp_tools_config)
|
||||||
|
|
||||||
|
## AGENTS
|
||||||
|
agent_config = config.get("agent_list", None)
|
||||||
|
if agent_config:
|
||||||
|
global_agent_registry.load_agents_from_config(agent_config) # type: ignore
|
||||||
|
|
||||||
mcp_servers_config = config.get("mcp_servers", None)
|
mcp_servers_config = config.get("mcp_servers", None)
|
||||||
if mcp_servers_config:
|
if mcp_servers_config:
|
||||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||||
@@ -2821,7 +2830,9 @@ class ProxyConfig:
|
|||||||
for k, v in _litellm_params.items():
|
for k, v in _litellm_params.items():
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
# decrypt value - returns original value if decryption fails or no key is set
|
# decrypt value - returns original value if decryption fails or no key is set
|
||||||
_value = decrypt_value_helper(value=v, key=k, return_original_value=True)
|
_value = decrypt_value_helper(
|
||||||
|
value=v, key=k, return_original_value=True
|
||||||
|
)
|
||||||
_litellm_params[k] = _value
|
_litellm_params[k] = _value
|
||||||
_litellm_params = LiteLLM_Params(**_litellm_params)
|
_litellm_params = LiteLLM_Params(**_litellm_params)
|
||||||
|
|
||||||
@@ -3414,6 +3425,9 @@ class ProxyConfig:
|
|||||||
if self._should_load_db_object(object_type="mcp"):
|
if self._should_load_db_object(object_type="mcp"):
|
||||||
await self._init_mcp_servers_in_db()
|
await self._init_mcp_servers_in_db()
|
||||||
|
|
||||||
|
if self._should_load_db_object(object_type="agents"):
|
||||||
|
await self._init_agents_in_db(prisma_client=prisma_client)
|
||||||
|
|
||||||
if self._should_load_db_object(object_type="pass_through_endpoints"):
|
if self._should_load_db_object(object_type="pass_through_endpoints"):
|
||||||
await self._init_pass_through_endpoints_in_db()
|
await self._init_pass_through_endpoints_in_db()
|
||||||
|
|
||||||
@@ -3682,6 +3696,25 @@ class ProxyConfig:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _init_agents_in_db(self, prisma_client: PrismaClient):
|
||||||
|
from litellm.proxy.agent_endpoints.agent_registry import (
|
||||||
|
global_agent_registry as AGENT_REGISTRY,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
db_agents = await AGENT_REGISTRY.get_all_agents_from_db(
|
||||||
|
prisma_client=prisma_client
|
||||||
|
)
|
||||||
|
AGENT_REGISTRY.load_agents_from_db_and_config(
|
||||||
|
db_agents=db_agents, agent_config=config_agents
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
verbose_proxy_logger.exception(
|
||||||
|
"litellm.proxy.proxy_server.py::ProxyConfig:_init_agents_in_db - {}".format(
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
async def _init_search_tools_in_db(self, prisma_client: PrismaClient):
|
async def _init_search_tools_in_db(self, prisma_client: PrismaClient):
|
||||||
"""
|
"""
|
||||||
Initialize search tools from database into the router on startup.
|
Initialize search tools from database into the router on startup.
|
||||||
@@ -8963,7 +8996,9 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
|||||||
if isinstance(
|
if isinstance(
|
||||||
config["litellm_settings"]["success_callback"], list
|
config["litellm_settings"]["success_callback"], list
|
||||||
) and isinstance(updated_litellm_settings["success_callback"], list):
|
) and isinstance(updated_litellm_settings["success_callback"], list):
|
||||||
updated_success_callbacks_normalized = normalize_callback_names(updated_litellm_settings["success_callback"])
|
updated_success_callbacks_normalized = normalize_callback_names(
|
||||||
|
updated_litellm_settings["success_callback"]
|
||||||
|
)
|
||||||
combined_success_callback = (
|
combined_success_callback = (
|
||||||
config["litellm_settings"]["success_callback"]
|
config["litellm_settings"]["success_callback"]
|
||||||
+ updated_success_callbacks_normalized
|
+ updated_success_callbacks_normalized
|
||||||
@@ -10097,6 +10132,7 @@ app.include_router(cache_settings_router)
|
|||||||
app.include_router(user_agent_analytics_router)
|
app.include_router(user_agent_analytics_router)
|
||||||
app.include_router(enterprise_router)
|
app.include_router(enterprise_router)
|
||||||
app.include_router(ui_discovery_endpoints_router)
|
app.include_router(ui_discovery_endpoints_router)
|
||||||
|
app.include_router(agent_endpoints_router)
|
||||||
########################################################
|
########################################################
|
||||||
# MCP Server
|
# MCP Server
|
||||||
########################################################
|
########################################################
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from litellm.proxy.public_endpoints.provider_create_metadata import (
|
|||||||
get_provider_create_metadata,
|
get_provider_create_metadata,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
|
from litellm.types.agents import AgentCard
|
||||||
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
from litellm.types.proxy.management_endpoints.model_management_endpoints import (
|
||||||
ModelGroupInfoProxy,
|
ModelGroupInfoProxy,
|
||||||
)
|
)
|
||||||
@@ -45,6 +46,19 @@ async def public_model_hub():
|
|||||||
return model_groups
|
return model_groups
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/public/agent_hub",
|
||||||
|
tags=["[beta] Agents", "public"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
response_model=List[AgentCard],
|
||||||
|
)
|
||||||
|
async def get_agents():
|
||||||
|
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
|
||||||
|
|
||||||
|
agents = global_agent_registry.get_public_agent_list()
|
||||||
|
return [agent.get("agent_card_params") for agent in agents]
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/public/model_hub/info",
|
"/public/model_hub/info",
|
||||||
tags=["public", "model management"],
|
tags=["public", "model management"],
|
||||||
|
|||||||
@@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable {
|
|||||||
updated_by String
|
updated_by String
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Agents on proxy
|
||||||
|
model LiteLLM_AgentsTable {
|
||||||
|
agent_id String @id @default(uuid())
|
||||||
|
agent_name String @unique
|
||||||
|
litellm_params Json?
|
||||||
|
agent_card_params Json
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
}
|
||||||
|
|
||||||
model LiteLLM_OrganizationTable {
|
model LiteLLM_OrganizationTable {
|
||||||
organization_id String @id @default(uuid())
|
organization_id String @id @default(uuid())
|
||||||
organization_alias String
|
organization_alias String
|
||||||
|
|||||||
186
litellm/types/agents.py
Normal file
186
litellm/types/agents.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
|
|
||||||
|
# AgentProvider
|
||||||
|
class AgentProvider(TypedDict, total=False):
|
||||||
|
"""Represents the service provider of an agent."""
|
||||||
|
|
||||||
|
organization: str # required
|
||||||
|
url: str # required
|
||||||
|
|
||||||
|
|
||||||
|
# AgentExtension
|
||||||
|
class AgentExtension(TypedDict, total=False):
|
||||||
|
"""A declaration of a protocol extension supported by an Agent."""
|
||||||
|
|
||||||
|
uri: str # required
|
||||||
|
description: Optional[str]
|
||||||
|
required: Optional[bool]
|
||||||
|
params: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
# AgentCapabilities
|
||||||
|
class AgentCapabilities(TypedDict, total=False):
|
||||||
|
"""Defines optional capabilities supported by an agent."""
|
||||||
|
|
||||||
|
streaming: Optional[bool]
|
||||||
|
pushNotifications: Optional[bool]
|
||||||
|
stateTransitionHistory: Optional[bool]
|
||||||
|
extensions: Optional[List[AgentExtension]]
|
||||||
|
|
||||||
|
|
||||||
|
# SecurityScheme types
|
||||||
|
class SecuritySchemeBase(TypedDict, total=False):
|
||||||
|
"""Base properties shared by all security scheme objects."""
|
||||||
|
|
||||||
|
description: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeySecurityScheme(SecuritySchemeBase):
|
||||||
|
"""Defines a security scheme using an API key."""
|
||||||
|
|
||||||
|
type: Literal["apiKey"]
|
||||||
|
in_: Literal["query", "header", "cookie"] # using in_ to avoid Python keyword
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPAuthSecurityScheme(SecuritySchemeBase):
|
||||||
|
"""Defines a security scheme using HTTP authentication."""
|
||||||
|
|
||||||
|
type: Literal["http"]
|
||||||
|
scheme: str
|
||||||
|
bearerFormat: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class MutualTLSSecurityScheme(SecuritySchemeBase):
|
||||||
|
"""Defines a security scheme using mTLS authentication."""
|
||||||
|
|
||||||
|
type: Literal["mutualTLS"]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthFlows(TypedDict, total=False):
|
||||||
|
"""Defines the configuration for the supported OAuth 2.0 flows."""
|
||||||
|
|
||||||
|
authorizationCode: Optional[Dict[str, Any]]
|
||||||
|
clientCredentials: Optional[Dict[str, Any]]
|
||||||
|
implicit: Optional[Dict[str, Any]]
|
||||||
|
password: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2SecurityScheme(SecuritySchemeBase):
|
||||||
|
"""Defines a security scheme using OAuth 2.0."""
|
||||||
|
|
||||||
|
type: Literal["oauth2"]
|
||||||
|
flows: OAuthFlows
|
||||||
|
oauth2MetadataUrl: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenIdConnectSecurityScheme(SecuritySchemeBase):
|
||||||
|
"""Defines a security scheme using OpenID Connect."""
|
||||||
|
|
||||||
|
type: Literal["openIdConnect"]
|
||||||
|
openIdConnectUrl: str
|
||||||
|
|
||||||
|
|
||||||
|
# Union of all security schemes
|
||||||
|
SecurityScheme = Union[
|
||||||
|
APIKeySecurityScheme,
|
||||||
|
HTTPAuthSecurityScheme,
|
||||||
|
OAuth2SecurityScheme,
|
||||||
|
OpenIdConnectSecurityScheme,
|
||||||
|
MutualTLSSecurityScheme,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# AgentSkill
|
||||||
|
class AgentSkill(TypedDict, total=False):
|
||||||
|
"""Represents a distinct capability or function that an agent can perform."""
|
||||||
|
|
||||||
|
id: str # required
|
||||||
|
name: str # required
|
||||||
|
description: str # required
|
||||||
|
tags: List[str] # required
|
||||||
|
examples: Optional[List[str]]
|
||||||
|
inputModes: Optional[List[str]]
|
||||||
|
outputModes: Optional[List[str]]
|
||||||
|
security: Optional[List[Dict[str, List[str]]]]
|
||||||
|
|
||||||
|
|
||||||
|
# AgentInterface
|
||||||
|
class AgentInterface(TypedDict, total=False):
|
||||||
|
"""Declares a combination of a target URL and a transport protocol."""
|
||||||
|
|
||||||
|
url: str # required
|
||||||
|
transport: str # required (TransportProtocol | string)
|
||||||
|
|
||||||
|
|
||||||
|
# AgentCardSignature
|
||||||
|
class AgentCardSignature(TypedDict, total=False):
|
||||||
|
"""Represents a JWS signature of an AgentCard."""
|
||||||
|
|
||||||
|
protected: str # required
|
||||||
|
signature: str # required
|
||||||
|
header: Optional[Dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
# AgentCard
|
||||||
|
class AgentCard(TypedDict, total=False):
|
||||||
|
"""
|
||||||
|
The AgentCard is a self-describing manifest for an agent.
|
||||||
|
It provides essential metadata including the agent's identity, capabilities,
|
||||||
|
skills, supported communication methods, and security requirements.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Required fields
|
||||||
|
protocolVersion: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
url: str
|
||||||
|
version: str
|
||||||
|
capabilities: AgentCapabilities
|
||||||
|
defaultInputModes: List[str]
|
||||||
|
defaultOutputModes: List[str]
|
||||||
|
skills: List[AgentSkill]
|
||||||
|
|
||||||
|
# Optional fields
|
||||||
|
preferredTransport: Optional[str]
|
||||||
|
additionalInterfaces: Optional[List[AgentInterface]]
|
||||||
|
iconUrl: Optional[str]
|
||||||
|
provider: Optional[AgentProvider]
|
||||||
|
documentationUrl: Optional[str]
|
||||||
|
securitySchemes: Optional[Dict[str, SecurityScheme]]
|
||||||
|
security: Optional[List[Dict[str, List[str]]]]
|
||||||
|
supportsAuthenticatedExtendedCard: Optional[bool]
|
||||||
|
signatures: Optional[List[AgentCardSignature]]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentLitellmParams(TypedDict):
|
||||||
|
make_public: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(TypedDict, total=False):
|
||||||
|
agent_name: Required[str]
|
||||||
|
agent_card_params: Required[AgentCard]
|
||||||
|
litellm_params: AgentLitellmParams
|
||||||
|
|
||||||
|
|
||||||
|
# Request/Response models for CRUD endpoints
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
agent_name: str
|
||||||
|
litellm_params: Optional[Dict[str, Any]] = None
|
||||||
|
agent_card_params: Dict[str, Any]
|
||||||
|
created_at: Optional[datetime] = None
|
||||||
|
updated_at: Optional[datetime] = None
|
||||||
|
created_by: Optional[str] = None
|
||||||
|
updated_by: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ListAgentsResponse(BaseModel):
|
||||||
|
agents: List[AgentResponse]
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"ignore": [],
|
"ignore": [],
|
||||||
"exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"],
|
"exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py", "litellm/proxy/_types.py"],
|
||||||
"reportMissingImports": false,
|
"reportMissingImports": false,
|
||||||
"reportPrivateImportUsage": false
|
"reportPrivateImportUsage": false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,6 +54,19 @@ model LiteLLM_ProxyModelTable {
|
|||||||
updated_by String
|
updated_by String
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Agents on proxy
|
||||||
|
model LiteLLM_AgentsTable {
|
||||||
|
agent_id String @id @default(uuid())
|
||||||
|
agent_name String @unique
|
||||||
|
litellm_params Json?
|
||||||
|
agent_card_params Json
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
}
|
||||||
|
|
||||||
model LiteLLM_OrganizationTable {
|
model LiteLLM_OrganizationTable {
|
||||||
organization_id String @id @default(uuid())
|
organization_id String @id @default(uuid())
|
||||||
organization_alias String
|
organization_alias String
|
||||||
|
|||||||
Reference in New Issue
Block a user