feat: add backend support for OAuth2 auth_type registration via UI (#17006)

This commit is contained in:
YutaSaito
2025-11-24 14:52:18 +09:00
committed by GitHub
parent f0b10b854b
commit b72b49757e
9 changed files with 939 additions and 230 deletions

View File

@@ -14,6 +14,7 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.types.mcp_server.mcp_server_manager import MCPServer
router = APIRouter(
tags=["mcp"],
@@ -122,6 +123,163 @@ def decode_state_hash(encrypted_state: str) -> dict:
return state_data
async def authorize_with_server(
request: Request,
mcp_server: MCPServer,
client_id: str,
redirect_uri: str,
state: str = "",
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
response_type: Optional[str] = None,
scope: Optional[str] = None,
):
if mcp_server.auth_type != "oauth2":
raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
if mcp_server.authorization_url is None:
raise HTTPException(
status_code=400, detail="MCP server authorization url is not set"
)
parsed = urlparse(redirect_uri)
base_url = urlunparse(parsed._replace(query=""))
request_base_url = get_request_base_url(request)
encoded_state = encode_state_with_base_url(
base_url=base_url,
original_state=state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
client_redirect_uri=redirect_uri,
)
params = {
"client_id": mcp_server.client_id if mcp_server.client_id else client_id,
"redirect_uri": f"{request_base_url}/callback",
"state": encoded_state,
"response_type": response_type or "code",
}
if scope:
params["scope"] = scope
elif mcp_server.scopes:
params["scope"] = " ".join(mcp_server.scopes)
if code_challenge:
params["code_challenge"] = code_challenge
if code_challenge_method:
params["code_challenge_method"] = code_challenge_method
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
async def exchange_token_with_server(
request: Request,
mcp_server: MCPServer,
grant_type: str,
code: Optional[str],
redirect_uri: Optional[str],
client_id: str,
client_secret: Optional[str],
code_verifier: Optional[str],
):
if grant_type != "authorization_code":
raise HTTPException(status_code=400, detail="Unsupported grant_type")
if mcp_server.token_url is None:
raise HTTPException(status_code=400, detail="MCP server token url is not set")
proxy_base_url = get_request_base_url(request)
token_data = {
"grant_type": "authorization_code",
"client_id": mcp_server.client_id if mcp_server.client_id else client_id,
"client_secret": mcp_server.client_secret
if mcp_server.client_secret
else client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
}
if code_verifier:
token_data["code_verifier"] = code_verifier
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
response = await async_client.post(
mcp_server.token_url,
headers={"Accept": "application/json"},
data=token_data,
)
response.raise_for_status()
token_response = response.json()
access_token = token_response["access_token"]
result = {
"access_token": access_token,
"token_type": token_response.get("token_type", "Bearer"),
"expires_in": token_response.get("expires_in", 3600),
}
if "refresh_token" in token_response and token_response["refresh_token"]:
result["refresh_token"] = token_response["refresh_token"]
if "scope" in token_response and token_response["scope"]:
result["scope"] = token_response["scope"]
return JSONResponse(result)
async def register_client_with_server(
request: Request,
mcp_server: MCPServer,
client_name: str,
grant_types: Optional[list],
response_types: Optional[list],
token_endpoint_auth_method: Optional[str],
fallback_client_id: Optional[str] = None,
):
request_base_url = get_request_base_url(request)
dummy_return = {
"client_id": fallback_client_id or mcp_server.server_name,
"client_secret": "dummy",
"redirect_uris": [f"{request_base_url}/callback"],
}
if mcp_server.client_id and mcp_server.client_secret:
return dummy_return
if mcp_server.authorization_url is None:
raise HTTPException(
status_code=400, detail="MCP server authorization url is not set"
)
if mcp_server.registration_url is None:
return dummy_return
register_data = {
"client_name": client_name,
"redirect_uris": [f"{request_base_url}/callback"],
"grant_types": grant_types or [],
"response_types": response_types or [],
"token_endpoint_auth_method": token_endpoint_auth_method or "",
}
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Oauth2Register
)
response = await async_client.post(
mcp_server.registration_url,
headers=headers,
json=register_data,
)
response.raise_for_status()
token_response = response.json()
return JSONResponse(token_response)
@router.get("/{mcp_server_name}/authorize")
@router.get("/authorize")
async def authorize(
@@ -140,53 +298,21 @@ async def authorize(
global_mcp_server_manager,
)
if mcp_server_name:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
else:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(client_id)
lookup_name = mcp_server_name or client_id
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
if mcp_server.auth_type != "oauth2":
raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
if mcp_server.authorization_url is None:
raise HTTPException(
status_code=400, detail="MCP server authorization url is not set"
)
# Parse it to remove any existing query
parsed = urlparse(redirect_uri)
base_url = urlunparse(parsed._replace(query=""))
# Get the correct base URL considering X-Forwarded-* headers
request_base_url = get_request_base_url(request)
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
encoded_state = encode_state_with_base_url(
base_url=base_url,
original_state=state,
return await authorize_with_server(
request=request,
mcp_server=mcp_server,
client_id=client_id,
redirect_uri=redirect_uri,
state=state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
client_redirect_uri=redirect_uri,
response_type=response_type,
scope=scope,
)
# Build params for upstream OAuth provider
params = {
"client_id": client_id if client_id else mcp_server.client_id,
"redirect_uri": f"{request_base_url}/callback",
"state": encoded_state,
"response_type": response_type or "code",
}
if scope:
params["scope"] = scope
elif mcp_server.scopes:
params["scope"] = " ".join(mcp_server.scopes)
# Forward PKCE parameters if present
if code_challenge:
params["code_challenge"] = code_challenge
if code_challenge_method:
params["code_challenge_method"] = code_challenge_method
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
@router.post("/{mcp_server_name}/token")
@@ -214,64 +340,21 @@ async def token_endpoint(
global_mcp_server_manager,
)
if mcp_server_name:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
else:
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(client_id)
lookup_name = mcp_server_name or client_id
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
if grant_type != "authorization_code":
raise HTTPException(status_code=400, detail="Unsupported grant_type")
if mcp_server.token_url is None:
raise HTTPException(status_code=400, detail="MCP server token url is not set")
# Get the correct base URL considering X-Forwarded-* headers
proxy_base_url = get_request_base_url(request)
# Build token request data
token_data = {
"grant_type": "authorization_code",
"client_id": client_id if client_id else mcp_server.client_id,
"client_secret": client_secret if client_secret else mcp_server.client_secret,
"code": code,
"redirect_uri": f"{proxy_base_url}/callback",
}
# Forward PKCE code_verifier if present
if code_verifier:
token_data["code_verifier"] = code_verifier
# Exchange code for real OAuth token
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
response = await async_client.post(
mcp_server.token_url,
headers={"Accept": "application/json"},
data=token_data,
return await exchange_token_with_server(
request=request,
mcp_server=mcp_server,
grant_type=grant_type,
code=code,
redirect_uri=redirect_uri,
client_id=client_id,
client_secret=client_secret,
code_verifier=code_verifier,
)
response.raise_for_status()
token_response = response.json()
access_token = token_response["access_token"]
# Return to client in expected OAuth 2 format
# Only include fields that have values
result = {
"access_token": access_token,
"token_type": token_response.get("token_type", "Bearer"),
"expires_in": token_response.get("expires_in", 3600),
}
# Add optional fields only if they exist
if "refresh_token" in token_response and token_response["refresh_token"]:
result["refresh_token"] = token_response["refresh_token"]
if "scope" in token_response and token_response["scope"]:
result["scope"] = token_response["scope"]
return JSONResponse(result)
@router.get("/callback")
async def callback(code: str, state: str):
@@ -391,44 +474,12 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
if mcp_server is None:
return dummy_return
if mcp_server.client_id and mcp_server.client_secret:
return {
"client_id": mcp_server.client_id,
"client_secret": mcp_server.client_secret,
"redirect_uris": [f"{request_base_url}/callback"],
}
if mcp_server.authorization_url is None:
raise HTTPException(
status_code=400, detail="MCP server authorization url is not set"
)
if mcp_server.registration_url is None:
return dummy_return
register_data = {
"client_name": data.get("client_name", ""),
"redirect_uris": [f"{request_base_url}/callback"],
"grant_types": data.get("grant_types", []),
"response_types": data.get("response_types", []),
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", ""),
}
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
}
async_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.Oauth2Register
return await register_client_with_server(
request=request,
mcp_server=mcp_server,
client_name=data.get("client_name", ""),
grant_types=data.get("grant_types", []),
response_types=data.get("response_types", []),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
fallback_client_id=mcp_server_name,
)
response = await async_client.post(
mcp_server.registration_url,
headers=headers,
json=register_data,
)
response.raise_for_status()
token_response = response.json()
return JSONResponse(token_response)

View File

@@ -395,12 +395,12 @@ class MCPServerManager:
)
# Update tool name to server name mapping (for both prefixed and base names)
self.tool_name_to_mcp_server_name_mapping[base_tool_name] = (
server_prefix
)
self.tool_name_to_mcp_server_name_mapping[prefixed_tool_name] = (
server_prefix
)
self.tool_name_to_mcp_server_name_mapping[
base_tool_name
] = server_prefix
self.tool_name_to_mcp_server_name_mapping[
prefixed_tool_name
] = server_prefix
registered_count += 1
verbose_logger.debug(
@@ -432,73 +432,127 @@ class MCPServerManager:
f"Server ID {mcp_server.server_id} not found in registry"
)
def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
async def build_mcp_server_from_table(
self,
mcp_server: LiteLLM_MCPServerTable,
*,
credentials_are_encrypted: bool = True,
) -> MCPServer:
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None))
static_headers_dict = _deserialize_json_dict(
getattr(mcp_server, "static_headers", None)
)
credentials_dict = _deserialize_json_dict(
getattr(mcp_server, "credentials", None)
)
encrypted_auth_value: Optional[str] = None
encrypted_client_id: Optional[str] = None
encrypted_client_secret: Optional[str] = None
if credentials_dict:
encrypted_auth_value = credentials_dict.get("auth_value")
encrypted_client_id = credentials_dict.get("client_id")
encrypted_client_secret = credentials_dict.get("client_secret")
auth_value: Optional[str] = None
if encrypted_auth_value:
if credentials_are_encrypted:
auth_value = decrypt_value_helper(
value=encrypted_auth_value,
key="auth_value",
exception_type="debug",
return_original_value=True,
)
else:
auth_value = encrypted_auth_value
client_id_value: Optional[str] = None
if encrypted_client_id:
if credentials_are_encrypted:
client_id_value = decrypt_value_helper(
value=encrypted_client_id,
key="client_id",
exception_type="debug",
return_original_value=True,
)
else:
client_id_value = encrypted_client_id
client_secret_value: Optional[str] = None
if encrypted_client_secret:
if credentials_are_encrypted:
client_secret_value = decrypt_value_helper(
value=encrypted_client_secret,
key="client_secret",
exception_type="debug",
return_original_value=True,
)
else:
client_secret_value = encrypted_client_secret
scopes: Optional[List[str]] = None
if credentials_dict:
scopes_value = credentials_dict.get("scopes")
if scopes_value is not None:
scopes = self._extract_scopes(scopes_value)
name_for_prefix = (
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
)
mcp_info: MCPInfo = _mcp_info.copy()
if "server_name" not in mcp_info:
mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id
if "description" not in mcp_info and mcp_server.description:
mcp_info["description"] = mcp_server.description
auth_type = cast(MCPAuthType, mcp_server.auth_type)
if mcp_server.url and auth_type == MCPAuth.oauth2:
mcp_oauth_metadata = await self._descovery_metadata(
server_url=mcp_server.url,
)
else:
mcp_oauth_metadata = None
resolved_scopes = scopes or (
mcp_oauth_metadata.scopes if mcp_oauth_metadata else None
)
new_server = MCPServer(
server_id=mcp_server.server_id,
name=name_for_prefix,
alias=getattr(mcp_server, "alias", None),
server_name=getattr(mcp_server, "server_name", None),
url=mcp_server.url,
transport=cast(MCPTransportType, mcp_server.transport),
auth_type=auth_type,
authentication_token=auth_value,
mcp_info=mcp_info,
extra_headers=getattr(mcp_server, "extra_headers", None),
static_headers=static_headers_dict,
client_id=client_id_value or getattr(mcp_server, "client_id", None),
client_secret=client_secret_value
or getattr(mcp_server, "client_secret", None),
scopes=resolved_scopes,
authorization_url=getattr(mcp_oauth_metadata, "authorization_url", None),
token_url=getattr(mcp_oauth_metadata, "token_url", None),
registration_url=getattr(mcp_oauth_metadata, "registration_url", None),
command=getattr(mcp_server, "command", None),
args=getattr(mcp_server, "args", None) or [],
env=env_dict,
access_groups=getattr(mcp_server, "mcp_access_groups", None),
allowed_tools=getattr(mcp_server, "allowed_tools", None),
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
)
return new_server
async def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
try:
if mcp_server.server_id not in self.get_registry():
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
# Use helper to deserialize dictionary
# Safely access env field which may not exist on Prisma model objects
env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None))
static_headers_dict = _deserialize_json_dict(
getattr(mcp_server, "static_headers", None)
)
credentials_dict = _deserialize_json_dict(
getattr(mcp_server, "credentials", None)
)
encrypted_auth_value: Optional[str] = None
if credentials_dict:
encrypted_auth_value = credentials_dict.get("auth_value")
auth_value: Optional[str] = None
if encrypted_auth_value:
auth_value = decrypt_value_helper(
value=encrypted_auth_value,
key="auth_value",
)
# Use alias for name if present, else server_name
name_for_prefix = (
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
)
# Preserve all custom fields from database while setting defaults for core fields
mcp_info: MCPInfo = _mcp_info.copy()
# Set default values for core fields if not present
if "server_name" not in mcp_info:
mcp_info["server_name"] = (
mcp_server.server_name or mcp_server.server_id
)
if "description" not in mcp_info and mcp_server.description:
mcp_info["description"] = mcp_server.description
new_server = MCPServer(
server_id=mcp_server.server_id,
name=name_for_prefix,
alias=getattr(mcp_server, "alias", None),
server_name=getattr(mcp_server, "server_name", None),
url=mcp_server.url,
transport=cast(MCPTransportType, mcp_server.transport),
auth_type=cast(MCPAuthType, mcp_server.auth_type),
authentication_token=auth_value,
mcp_info=mcp_info,
extra_headers=getattr(mcp_server, "extra_headers", None),
static_headers=static_headers_dict,
# oauth specific fields
client_id=getattr(mcp_server, "client_id", None),
client_secret=getattr(mcp_server, "client_secret", None),
scopes=getattr(mcp_server, "scopes", None),
authorization_url=getattr(mcp_server, "authorization_url", None),
token_url=getattr(mcp_server, "token_url", None),
registration_url=getattr(mcp_server, "registration_url", None),
# Stdio-specific fields
command=getattr(mcp_server, "command", None),
args=getattr(mcp_server, "args", None) or [],
env=env_dict,
access_groups=getattr(mcp_server, "mcp_access_groups", None),
allowed_tools=getattr(mcp_server, "allowed_tools", None),
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
)
if mcp_server.server_id not in self.registry:
new_server = await self.build_mcp_server_from_table(mcp_server)
self.registry[mcp_server.server_id] = new_server
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")
verbose_logger.debug(f"Added MCP Server: {new_server.name}")
except Exception as e:
verbose_logger.debug(f"Failed to add MCP server: {str(e)}")

View File

@@ -293,7 +293,11 @@ if MCP_AVAILABLE:
NewMCPServerRequest,
)
async def _execute_with_mcp_client(request: NewMCPServerRequest, operation):
async def _execute_with_mcp_client(
request: NewMCPServerRequest,
operation,
oauth2_headers: Optional[Dict[str, str]] = None,
):
"""
Common helper to create MCP client, execute operation, and ensure proper cleanup.
@@ -315,6 +319,7 @@ if MCP_AVAILABLE:
mcp_info=request.mcp_info,
),
mcp_auth_header=None,
extra_headers=oauth2_headers,
)
return await operation(client)
@@ -342,12 +347,19 @@ if MCP_AVAILABLE:
@router.post("/test/tools/list")
async def test_tools_list(
request: NewMCPServerRequest,
request: Request,
new_mcp_server_request: NewMCPServerRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Preview tools available from MCP server before adding it
"""
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)
headers = request.headers
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers)
async def _list_tools_operation(client):
async def _list_tools_session_operation(session):
@@ -366,4 +378,6 @@ if MCP_AVAILABLE:
"message": "Successfully retrieved tools",
}
return await _execute_with_mcp_client(request, _list_tools_operation)
return await _execute_with_mcp_client(
new_mcp_server_request, _list_tools_operation, oauth2_headers
)

View File

@@ -14,13 +14,24 @@ Endpoints here:
"""
import importlib
from datetime import datetime
from typing import Iterable, List, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, Dict, Iterable, List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Response, status
from fastapi import (
APIRouter,
Depends,
Form,
Header,
HTTPException,
Request,
Response,
status,
)
from fastapi.responses import JSONResponse
import litellm
from litellm._uuid import uuid
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
from litellm.proxy._experimental.mcp_server.utils import (
@@ -29,6 +40,7 @@ from litellm.proxy._experimental.mcp_server.utils import (
router = APIRouter(prefix="/v1/mcp", tags=["mcp"])
MCP_AVAILABLE: bool = True
TEMPORARY_MCP_SERVER_TTL_SECONDS = 300
try:
importlib.import_module("mcp")
except ImportError as e:
@@ -43,9 +55,15 @@ if MCP_AVAILABLE:
get_mcp_server,
update_mcp_server,
)
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
authorize_with_server,
exchange_token_with_server,
register_client_with_server,
)
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LitellmUserRoles,
@@ -58,6 +76,47 @@ if MCP_AVAILABLE:
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.types.mcp import MCPCredentials
from litellm.types.mcp_server.mcp_server_manager import MCPServer
@dataclass
class _TemporaryMCPServerEntry:
server: MCPServer
expires_at: datetime
_temporary_mcp_servers: Dict[str, _TemporaryMCPServerEntry] = {}
def _prune_expired_temporary_mcp_servers() -> None:
if not _temporary_mcp_servers:
return
now = datetime.utcnow()
expired_ids = [
server_id
for server_id, entry in _temporary_mcp_servers.items()
if entry.expires_at <= now
]
for server_id in expired_ids:
_temporary_mcp_servers.pop(server_id, None)
def _cache_temporary_mcp_server(server: MCPServer, ttl_seconds: int) -> MCPServer:
ttl_seconds = max(1, ttl_seconds)
_prune_expired_temporary_mcp_servers()
expires_at = datetime.utcnow() + timedelta(seconds=ttl_seconds)
_temporary_mcp_servers[server.server_id] = _TemporaryMCPServerEntry(
server=server,
expires_at=expires_at,
)
return server
def get_cached_temporary_mcp_server(
server_id: str,
) -> Optional[MCPServer]:
_prune_expired_temporary_mcp_servers()
entry = _temporary_mcp_servers.get(server_id)
if entry is None:
return None
return entry.server
def _redact_mcp_credentials(
mcp_server: LiteLLM_MCPServerTable,
@@ -79,6 +138,75 @@ if MCP_AVAILABLE:
) -> List[LiteLLM_MCPServerTable]:
return [_redact_mcp_credentials(server) for server in mcp_servers]
def _inherit_credentials_from_existing_server(
payload: NewMCPServerRequest,
) -> NewMCPServerRequest:
if not payload.server_id or payload.credentials:
return payload
existing_server = global_mcp_server_manager.get_mcp_server_by_id(
payload.server_id
)
if existing_server is None:
return payload
inherited_credentials: MCPCredentials = {}
if existing_server.authentication_token:
inherited_credentials["auth_value"] = existing_server.authentication_token
if existing_server.client_id:
inherited_credentials["client_id"] = existing_server.client_id
if existing_server.client_secret:
inherited_credentials["client_secret"] = existing_server.client_secret
if existing_server.scopes:
inherited_credentials["scopes"] = existing_server.scopes
if not inherited_credentials:
return payload
try:
return payload.model_copy(update={"credentials": inherited_credentials})
except AttributeError:
pass
payload_dict: Dict[str, Any]
try:
payload_dict = payload.model_dump() # type: ignore[attr-defined]
except AttributeError:
payload_dict = payload.dict() # type: ignore[attr-defined]
payload_dict["credentials"] = inherited_credentials
return NewMCPServerRequest(**payload_dict)
def _build_temporary_mcp_server_record(
payload: NewMCPServerRequest,
created_by: Optional[str],
) -> LiteLLM_MCPServerTable:
now = datetime.utcnow()
server_id = payload.server_id or str(uuid.uuid4())
server_name = payload.server_name or payload.alias or server_id
return LiteLLM_MCPServerTable(
server_id=server_id,
server_name=server_name,
alias=payload.alias,
description=payload.description,
url=payload.url,
transport=payload.transport,
auth_type=payload.auth_type,
credentials=payload.credentials,
created_at=now,
updated_at=now,
created_by=created_by,
updated_by=created_by,
teams=[],
mcp_access_groups=payload.mcp_access_groups,
allowed_tools=payload.allowed_tools or [],
extra_headers=payload.extra_headers or [],
mcp_info=payload.mcp_info,
static_headers=payload.static_headers,
command=payload.command,
args=payload.args,
env=payload.env,
)
def get_prisma_client_or_throw(message: str):
from litellm.proxy.proxy_server import prisma_client
@@ -376,7 +504,7 @@ if MCP_AVAILABLE:
exists = does_mcp_server_exist(mcp_server_records, server_id)
if exists:
global_mcp_server_manager.add_update_server(mcp_server)
await global_mcp_server_manager.add_update_server(mcp_server)
return _redact_mcp_credentials(mcp_server)
else:
raise HTTPException(
@@ -450,7 +578,7 @@ if MCP_AVAILABLE:
payload,
touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
)
global_mcp_server_manager.add_update_server(new_mcp_server)
await global_mcp_server_manager.add_update_server(new_mcp_server)
# Ensure registry is up to date by reloading from database
await global_mcp_server_manager.reload_servers_from_database()
@@ -462,6 +590,151 @@ if MCP_AVAILABLE:
)
return _redact_mcp_credentials(new_mcp_server)
@router.post(
"/server/oauth/session",
description="Temporarily cache an MCP server in memory without writing to the database",
dependencies=[Depends(user_api_key_auth)],
status_code=status.HTTP_200_OK,
)
@management_endpoint_wrapper
async def add_session_mcp_server(
payload: NewMCPServerRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
litellm_changed_by: Optional[str] = Header(
None,
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
),
):
"""
Cache MCP server info in memory for a short duration (~5 minutes).
This endpoint does not write to the database. If the same server_id is provided
again while the cache entry is active, it will refresh the cached data + TTL.
"""
# Validate and normalize payload fields (alias/server name rules)
validate_and_normalize_mcp_server_payload(payload)
# Restrict to proxy admins similar to the persistent create endpoint
if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "User does not have permission to create temporary mcp servers. You can only create temporary mcp servers if you are a PROXY_ADMIN."
},
)
created_by = user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME
payload_with_credentials = _inherit_credentials_from_existing_server(payload)
temp_record = _build_temporary_mcp_server_record(
payload_with_credentials,
created_by,
)
try:
temporary_server = (
await global_mcp_server_manager.build_mcp_server_from_table(
temp_record,
credentials_are_encrypted=False,
)
)
_cache_temporary_mcp_server(
temporary_server,
ttl_seconds=TEMPORARY_MCP_SERVER_TTL_SECONDS,
)
except Exception as e:
verbose_proxy_logger.exception(
f"Error caching temporary mcp server: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={"error": f"Error caching temporary mcp server: {str(e)}"},
)
return _redact_mcp_credentials(temp_record)
def _get_cached_temporary_mcp_server_or_404(server_id: str) -> MCPServer:
server = get_cached_temporary_mcp_server(server_id)
if server is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"error": f"Temporary MCP server {server_id} not found"},
)
return server
@router.get(
"/server/oauth/{server_id}/authorize",
include_in_schema=False,
)
async def mcp_authorize(
request: Request,
server_id: str,
client_id: str,
redirect_uri: str,
state: str = "",
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
response_type: Optional[str] = None,
scope: Optional[str] = None,
):
mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
return await authorize_with_server(
request=request,
mcp_server=mcp_server,
client_id=client_id,
redirect_uri=redirect_uri,
state=state,
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
response_type=response_type,
scope=scope,
)
@router.post(
"/server/oauth/{server_id}/token",
include_in_schema=False,
)
async def mcp_token(
request: Request,
server_id: str,
grant_type: str = Form(...),
code: Optional[str] = Form(None),
redirect_uri: Optional[str] = Form(None),
client_id: str = Form(...),
client_secret: Optional[str] = Form(None),
code_verifier: Optional[str] = Form(None),
):
mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
return await exchange_token_with_server(
request=request,
mcp_server=mcp_server,
grant_type=grant_type,
code=code,
redirect_uri=redirect_uri,
client_id=client_id,
client_secret=client_secret,
code_verifier=code_verifier,
)
@router.post(
"/server/oauth/{server_id}/register",
include_in_schema=False,
)
async def mcp_register(request: Request, server_id: str):
mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
return await register_client_with_server(
request=request,
mcp_server=mcp_server,
client_name=data.get("client_name", ""),
grant_types=data.get("grant_types", []),
response_types=data.get("response_types", []),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
fallback_client_id=server_id,
)
@router.delete(
"/server/{server_id}",
description="Allows deleting mcp serves in the db",
@@ -586,7 +859,7 @@ if MCP_AVAILABLE:
"error": f"MCP Server not found, passed server_id={payload.server_id}"
},
)
global_mcp_server_manager.add_update_server(mcp_server_record_updated)
await global_mcp_server_manager.add_update_server(mcp_server_record_updated)
# Ensure registry is up to date by reloading from database
await global_mcp_server_manager.reload_servers_from_database()

View File

@@ -1048,7 +1048,7 @@ async def test_mcp_server_manager_config_integration_with_database():
)
# Test the add_update_server method (this tests our fix)
test_manager.add_update_server(db_server)
await test_manager.add_update_server(db_server)
# Verify the server was added with correct access_groups
registry = test_manager.get_registry()
@@ -1342,7 +1342,8 @@ async def test_mcp_server_manager_server_id_tool_prefixing():
)
def test_add_update_server_with_alias():
@pytest.mark.asyncio
async def test_add_update_server_with_alias():
"""
Test that add_update_server correctly handles servers with alias.
"""
@@ -1371,7 +1372,7 @@ def test_add_update_server_with_alias():
mock_mcp_server.token_url = None
# Add server to manager
test_manager.add_update_server(mock_mcp_server)
await test_manager.add_update_server(mock_mcp_server)
# Verify server was added with correct name (should use alias)
assert "test-server-123" in test_manager.registry
@@ -1381,7 +1382,8 @@ def test_add_update_server_with_alias():
assert added_server.server_name == "Test Server"
def test_add_update_server_without_alias():
@pytest.mark.asyncio
async def test_add_update_server_without_alias():
"""
Test that add_update_server correctly handles servers without alias.
"""
@@ -1410,7 +1412,7 @@ def test_add_update_server_without_alias():
mock_mcp_server.token_url = None
# Add server to manager
test_manager.add_update_server(mock_mcp_server)
await test_manager.add_update_server(mock_mcp_server)
# Verify server was added with correct name (should use server_name)
assert "test-server-123" in test_manager.registry
@@ -1420,7 +1422,8 @@ def test_add_update_server_without_alias():
assert added_server.server_name == "Test Server"
def test_add_update_server_fallback_to_server_id():
@pytest.mark.asyncio
async def test_add_update_server_fallback_to_server_id():
"""
Test that add_update_server falls back to server_id when neither alias nor server_name are available.
"""
@@ -1449,7 +1452,7 @@ def test_add_update_server_fallback_to_server_id():
mock_mcp_server.token_url = None
# Add server to manager
test_manager.add_update_server(mock_mcp_server)
await test_manager.add_update_server(mock_mcp_server)
# Verify server was added with correct name (should use server_id)
assert "test-server-123" in test_manager.registry

View File

@@ -310,8 +310,8 @@ async def test_register_client_returns_existing_server_credentials():
global_mcp_server_manager.registry.clear()
assert result == {
"client_id": "existing-client",
"client_secret": "existing-secret",
"client_id": "stored_server",
"client_secret": "dummy",
"redirect_uris": ["https://proxy.litellm.example/callback"],
}

View File

@@ -66,7 +66,7 @@ class TestMCPCustomFields:
assert mcp_info["priority"] == 10
assert mcp_info["tags"] == ["production", "api"]
def test_custom_fields_preserved_from_database(self):
async def test_custom_fields_preserved_from_database(self):
"""Test that custom fields in mcp_info are preserved when adding from database."""
manager = MCPServerManager()
@@ -92,7 +92,7 @@ class TestMCPCustomFields:
mock_server.mcp_access_groups = None
# Add server to manager
manager.add_update_server(mock_server)
await manager.add_update_server(mock_server)
# Get the added server
server = manager.get_mcp_server_by_id("test-server-id")

View File

@@ -58,7 +58,7 @@ class TestMCPServerManager:
updated_at=datetime.now(),
)
manager.add_update_server(stdio_server)
await manager.add_update_server(stdio_server)
# Verify server was added
assert "stdio-server-1" in manager.registry
@@ -1265,7 +1265,7 @@ class TestMCPServerManager:
"env": {},
},
)
manager.add_update_server(server)
await manager.add_update_server(server)
assert server.server_id in manager.get_registry()
@pytest.mark.asyncio

View File

@@ -2,7 +2,7 @@ import json
import os
import sys
from litellm._uuid import uuid
from datetime import datetime
from datetime import datetime, timedelta
from typing import List
from unittest.mock import AsyncMock, MagicMock, patch
@@ -19,6 +19,7 @@ from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LitellmUserRoles,
MCPTransport,
NewMCPServerRequest,
UserAPIKeyAuth,
)
from litellm.types.mcp import MCPAuth
@@ -854,3 +855,316 @@ class TestMCPHealthCheckEndpoints:
assert server.last_health_check is not None
assert server.health_check_error is None
assert server.credentials is None
class TestTemporaryMCPSessionEndpoints:
def test_inherit_credentials_from_existing_server(self):
payload = NewMCPServerRequest(
server_id="server-123",
alias="Temp Server",
url="https://temp.example.com",
transport=MCPTransport.http,
)
existing_server = MagicMock()
existing_server.authentication_token = "token-abc"
existing_server.client_id = "client-123"
existing_server.client_secret = "secret-xyz"
existing_server.scopes = ["scope:a", "scope:b"]
mock_manager = MagicMock()
mock_manager.get_mcp_server_by_id.return_value = existing_server
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
mock_manager,
):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
_inherit_credentials_from_existing_server,
)
updated_payload = _inherit_credentials_from_existing_server(payload)
assert updated_payload.credentials == {
"auth_value": "token-abc",
"client_id": "client-123",
"client_secret": "secret-xyz",
"scopes": ["scope:a", "scope:b"],
}
mock_manager.get_mcp_server_by_id.assert_called_once_with("server-123")
def test_cache_temporary_mcp_server_stores_entry_with_ttl(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
_cache_temporary_mcp_server,
)
server = generate_mock_mcp_server_config_record(server_id="temp-cache")
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._temporary_mcp_servers",
{},
) as cache:
cached_server = _cache_temporary_mcp_server(server, ttl_seconds=2)
assert cached_server is server
assert "temp-cache" in cache
assert cache["temp-cache"].server is server
assert cache["temp-cache"].expires_at > datetime.utcnow()
def test_get_cached_temporary_mcp_server_prunes_expired_entries(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
_TemporaryMCPServerEntry,
get_cached_temporary_mcp_server,
)
server = generate_mock_mcp_server_config_record(server_id="expired")
expired_entry = _TemporaryMCPServerEntry(
server=server,
expires_at=datetime.utcnow() - timedelta(seconds=30),
)
cache = {"expired": expired_entry}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._temporary_mcp_servers",
cache,
):
result = get_cached_temporary_mcp_server("expired")
assert result is None
assert "expired" not in cache
def test_get_cached_temporary_mcp_server_or_404(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
_get_cached_temporary_mcp_server_or_404,
)
server = generate_mock_mcp_server_config_record(server_id="cached")
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_cached_temporary_mcp_server",
return_value=server,
) as get_cached:
result = _get_cached_temporary_mcp_server_or_404("cached")
assert result is server
get_cached.assert_called_once_with("cached")
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_cached_temporary_mcp_server",
return_value=None,
):
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
_get_cached_temporary_mcp_server_or_404("missing")
assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_add_session_mcp_server_caches_and_redacts_credentials(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
TEMPORARY_MCP_SERVER_TTL_SECONDS,
add_session_mcp_server,
)
payload = NewMCPServerRequest(
server_id="temp-server",
alias="Temporary",
url="https://temp.example.com",
transport=MCPTransport.http,
)
user_auth = generate_mock_user_api_key_auth(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id="admin-user",
)
inherited_server = MagicMock(
authentication_token="token-abc",
client_id="client-id",
client_secret="client-secret",
scopes=["scope1"],
)
built_server = generate_mock_mcp_server_config_record(server_id="temp-server")
mock_manager = MagicMock()
mock_manager.get_mcp_server_by_id.return_value = inherited_server
mock_manager.build_mcp_server_from_table = AsyncMock(return_value=built_server)
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
MagicMock(),
) as validate_mock, patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
mock_manager,
), patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._cache_temporary_mcp_server",
MagicMock(),
) as cache_mock:
response = await add_session_mcp_server(
payload=payload,
user_api_key_dict=user_auth,
)
validate_mock.assert_called_once_with(payload)
mock_manager.build_mcp_server_from_table.assert_awaited_once()
cache_mock.assert_called_once_with(
built_server, ttl_seconds=TEMPORARY_MCP_SERVER_TTL_SECONDS
)
args, _ = mock_manager.build_mcp_server_from_table.call_args
temp_record = args[0]
assert temp_record.credentials == {
"auth_value": "token-abc",
"client_id": "client-id",
"client_secret": "client-secret",
"scopes": ["scope1"],
}
assert response.credentials is None
@pytest.mark.asyncio
async def test_add_session_mcp_server_rejects_non_admins(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
add_session_mcp_server,
)
payload = NewMCPServerRequest(
alias="Temporary",
server_id="temp-server",
url="https://temp.example.com",
transport=MCPTransport.http,
)
non_admin = generate_mock_user_api_key_auth(
user_role=LitellmUserRoles.INTERNAL_USER,
)
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
MagicMock(),
):
with pytest.raises(Exception) as exc_info:
await add_session_mcp_server(
payload=payload,
user_api_key_dict=non_admin,
)
assert "permission" in str(exc_info.value)
@pytest.mark.asyncio
async def test_mcp_authorize_proxies_to_discoverable_endpoint(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
mcp_authorize,
)
request = MagicMock()
server = generate_mock_mcp_server_config_record(server_id="server-1")
authorize_response = MagicMock()
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
return_value=server,
) as get_server, patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.authorize_with_server",
AsyncMock(return_value=authorize_response),
) as authorize_mock:
result = await mcp_authorize(
request=request,
server_id="server-1",
client_id="client-id",
redirect_uri="https://example.com/callback",
state="state123",
code_challenge="challenge",
code_challenge_method="S256",
response_type="code",
scope="scope1",
)
assert result is authorize_response
get_server.assert_called_once_with("server-1")
authorize_mock.assert_awaited_once_with(
request=request,
mcp_server=server,
client_id="client-id",
redirect_uri="https://example.com/callback",
state="state123",
code_challenge="challenge",
code_challenge_method="S256",
response_type="code",
scope="scope1",
)
@pytest.mark.asyncio
async def test_mcp_token_proxies_to_exchange_endpoint(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
mcp_token,
)
request = MagicMock()
server = generate_mock_mcp_server_config_record(server_id="server-1")
exchange_response = {"access_token": "token"}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
return_value=server,
) as get_server, patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.exchange_token_with_server",
AsyncMock(return_value=exchange_response),
) as exchange_mock:
result = await mcp_token(
request=request,
server_id="server-1",
grant_type="authorization_code",
code="code-123",
redirect_uri="https://example.com/callback",
client_id="client",
client_secret="secret",
code_verifier="verifier",
)
assert result is exchange_response
get_server.assert_called_once_with("server-1")
exchange_mock.assert_awaited_once_with(
request=request,
mcp_server=server,
grant_type="authorization_code",
code="code-123",
redirect_uri="https://example.com/callback",
client_id="client",
client_secret="secret",
code_verifier="verifier",
)
@pytest.mark.asyncio
async def test_mcp_register_proxies_request_body(self):
from litellm.proxy.management_endpoints.mcp_management_endpoints import (
mcp_register,
)
request = MagicMock()
server = generate_mock_mcp_server_config_record(server_id="server-1")
register_response = {"client_id": "generated"}
request_body = {
"client_name": "LiteLLM",
"grant_types": ["authorization_code"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_basic",
}
with patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
return_value=server,
) as get_server, patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints._read_request_body",
AsyncMock(return_value=request_body),
) as read_body, patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.register_client_with_server",
AsyncMock(return_value=register_response),
) as register_mock:
result = await mcp_register(request=request, server_id="server-1")
assert result is register_response
get_server.assert_called_once_with("server-1")
read_body.assert_awaited_once_with(request=request)
register_mock.assert_awaited_once_with(
request=request,
mcp_server=server,
client_name="LiteLLM",
grant_types=["authorization_code"],
response_types=["code"],
token_endpoint_auth_method="client_secret_basic",
fallback_client_id="server-1",
)