mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
feat: add backend support for OAuth2 auth_type registration via UI (#17006)
This commit is contained in:
@@ -14,6 +14,7 @@ from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
router = APIRouter(
|
||||
tags=["mcp"],
|
||||
@@ -122,6 +123,163 @@ def decode_state_hash(encrypted_state: str) -> dict:
|
||||
return state_data
|
||||
|
||||
|
||||
async def authorize_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Redirect the caller to the upstream OAuth2 authorization endpoint of ``mcp_server``.

    The client's own ``redirect_uri`` (and its query-less base URL), the original
    ``state`` and the PKCE parameters are packed into the ``state`` value sent
    upstream via ``encode_state_with_base_url``; the upstream provider is asked to
    redirect back to this proxy's ``/callback`` route instead of to the client.

    Args:
        request: Incoming request, used to derive the proxy base URL.
        mcp_server: Server whose ``authorization_url`` is the redirect target.
        client_id: Client id supplied by the caller; only used when the server
            has no ``client_id`` of its own configured.
        redirect_uri: The client's redirect URI, preserved inside the state.
        state: Opaque client state, preserved inside the encoded state.
        code_challenge: PKCE challenge, forwarded upstream when present.
        code_challenge_method: PKCE challenge method, forwarded when present.
        response_type: Forwarded upstream; defaults to ``"code"``.
        scope: Requested scope; falls back to the server's configured scopes.

    Returns:
        A ``RedirectResponse`` pointing at the upstream authorization URL.

    Raises:
        HTTPException: 400 if the server is not OAuth2 or has no authorization URL.
    """
    if mcp_server.auth_type != "oauth2":
        raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # Strip any existing query from the client's redirect URI to get its base URL.
    parsed = urlparse(redirect_uri)
    base_url = urlunparse(parsed._replace(query=""))

    request_base_url = get_request_base_url(request)
    encoded_state = encode_state_with_base_url(
        base_url=base_url,
        original_state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        client_redirect_uri=redirect_uri,
    )

    # Build params for the upstream OAuth provider. A client_id configured on the
    # server takes precedence over the one supplied by the caller.
    params = {
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "redirect_uri": f"{request_base_url}/callback",
        "state": encoded_state,
        "response_type": response_type or "code",
    }
    if scope:
        params["scope"] = scope
    elif mcp_server.scopes:
        params["scope"] = " ".join(mcp_server.scopes)

    # Forward PKCE parameters if present.
    if code_challenge:
        params["code_challenge"] = code_challenge
    if code_challenge_method:
        params["code_challenge_method"] = code_challenge_method

    # The authorization endpoint may already carry a query component, which must
    # be retained (RFC 6749 section 3.1). Blindly appending "?" would yield a
    # malformed URL such as ".../authorize?tenant=x?client_id=...".
    authorization_url = mcp_server.authorization_url
    if "?" not in authorization_url:
        separator = "?"
    elif authorization_url.endswith(("?", "&")):
        separator = ""
    else:
        separator = "&"

    return RedirectResponse(f"{authorization_url}{separator}{urlencode(params)}")
|
||||
|
||||
|
||||
async def exchange_token_with_server(
    request: Request,
    mcp_server: MCPServer,
    grant_type: str,
    code: Optional[str],
    redirect_uri: Optional[str],
    client_id: str,
    client_secret: Optional[str],
    code_verifier: Optional[str],
):
    """
    Exchange an authorization code for a token at the MCP server's token endpoint.

    Credentials configured on ``mcp_server`` win over the ``client_id`` /
    ``client_secret`` passed by the caller. The ``redirect_uri`` argument is
    accepted but not sent upstream: the proxy's own ``/callback`` URL is used.

    Returns:
        A ``JSONResponse`` with ``access_token``, ``token_type`` (default
        ``"Bearer"``), ``expires_in`` (default ``3600``) and, when the upstream
        response carries non-empty values, ``refresh_token`` and ``scope``.

    Raises:
        HTTPException: 400 for a grant type other than ``authorization_code``
            or when the server has no token URL.
    """
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="Unsupported grant_type")

    token_url = mcp_server.token_url
    if token_url is None:
        raise HTTPException(status_code=400, detail="MCP server token url is not set")

    callback_uri = f"{get_request_base_url(request)}/callback"

    form_fields = {
        "grant_type": "authorization_code",
        "client_id": mcp_server.client_id or client_id,
        "client_secret": mcp_server.client_secret or client_secret,
        "code": code,
        "redirect_uri": callback_uri,
    }
    # PKCE verifier is only forwarded when the client supplied one.
    if code_verifier:
        form_fields["code_verifier"] = code_verifier

    http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    upstream_response = await http_client.post(
        token_url,
        headers={"Accept": "application/json"},
        data=form_fields,
    )
    upstream_response.raise_for_status()
    upstream_payload = upstream_response.json()

    oauth_result = {
        "access_token": upstream_payload["access_token"],
        "token_type": upstream_payload.get("token_type", "Bearer"),
        "expires_in": upstream_payload.get("expires_in", 3600),
    }
    # Optional fields are copied only when present and non-empty.
    for optional_field in ("refresh_token", "scope"):
        if optional_field in upstream_payload and upstream_payload[optional_field]:
            oauth_result[optional_field] = upstream_payload[optional_field]

    return JSONResponse(oauth_result)
|
||||
|
||||
|
||||
async def register_client_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_name: str,
    grant_types: Optional[list],
    response_types: Optional[list],
    token_endpoint_auth_method: Optional[str],
    fallback_client_id: Optional[str] = None,
):
    """
    Register an OAuth2 client for ``mcp_server``, or return a placeholder registration.

    A placeholder (client id = ``fallback_client_id`` or the server name, secret
    ``"dummy"``, redirect URI = the proxy's ``/callback``) is returned when the
    server already has both a client id and secret configured, or when it has
    no registration URL. Otherwise the registration request is posted to the
    server's registration URL and its JSON reply is passed back.

    Returns:
        The placeholder ``dict``, or a ``JSONResponse`` wrapping the upstream reply.

    Raises:
        HTTPException: 400 when no pre-configured credentials exist and the
            server has no authorization URL.
    """
    callback_uri = f"{get_request_base_url(request)}/callback"
    placeholder_registration = {
        "client_id": fallback_client_id or mcp_server.server_name,
        "client_secret": "dummy",
        "redirect_uris": [callback_uri],
    }

    # Pre-configured credentials: nothing to register upstream.
    if mcp_server.client_id and mcp_server.client_secret:
        return placeholder_registration

    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    registration_url = mcp_server.registration_url
    if registration_url is None:
        return placeholder_registration

    registration_payload = {
        "client_name": client_name,
        "redirect_uris": [callback_uri],
        "grant_types": grant_types or [],
        "response_types": response_types or [],
        "token_endpoint_auth_method": token_endpoint_auth_method or "",
    }
    json_headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
    }

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Register
    )
    upstream_response = await http_client.post(
        registration_url,
        headers=json_headers,
        json=registration_payload,
    )
    upstream_response.raise_for_status()

    return JSONResponse(upstream_response.json())
|
||||
|
||||
|
||||
@router.get("/{mcp_server_name}/authorize")
|
||||
@router.get("/authorize")
|
||||
async def authorize(
|
||||
@@ -140,53 +298,21 @@ async def authorize(
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
if mcp_server_name:
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
|
||||
else:
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(client_id)
|
||||
lookup_name = mcp_server_name or client_id
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
|
||||
if mcp_server is None:
|
||||
raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
if mcp_server.auth_type != "oauth2":
|
||||
raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
|
||||
if mcp_server.authorization_url is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="MCP server authorization url is not set"
|
||||
)
|
||||
|
||||
# Parse it to remove any existing query
|
||||
parsed = urlparse(redirect_uri)
|
||||
base_url = urlunparse(parsed._replace(query=""))
|
||||
|
||||
# Get the correct base URL considering X-Forwarded-* headers
|
||||
request_base_url = get_request_base_url(request)
|
||||
|
||||
# Encode the base_url, original state, PKCE params, and client redirect_uri in encrypted state
|
||||
encoded_state = encode_state_with_base_url(
|
||||
base_url=base_url,
|
||||
original_state=state,
|
||||
return await authorize_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
client_id=client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
state=state,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
client_redirect_uri=redirect_uri,
|
||||
response_type=response_type,
|
||||
scope=scope,
|
||||
)
|
||||
# Build params for upstream OAuth provider
|
||||
params = {
|
||||
"client_id": client_id if client_id else mcp_server.client_id,
|
||||
"redirect_uri": f"{request_base_url}/callback",
|
||||
"state": encoded_state,
|
||||
"response_type": response_type or "code",
|
||||
}
|
||||
if scope:
|
||||
params["scope"] = scope
|
||||
elif mcp_server.scopes:
|
||||
params["scope"] = " ".join(mcp_server.scopes)
|
||||
|
||||
# Forward PKCE parameters if present
|
||||
if code_challenge:
|
||||
params["code_challenge"] = code_challenge
|
||||
if code_challenge_method:
|
||||
params["code_challenge_method"] = code_challenge_method
|
||||
|
||||
return RedirectResponse(f"{mcp_server.authorization_url}?{urlencode(params)}")
|
||||
|
||||
|
||||
@router.post("/{mcp_server_name}/token")
|
||||
@@ -214,64 +340,21 @@ async def token_endpoint(
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
if mcp_server_name:
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
|
||||
else:
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(client_id)
|
||||
|
||||
lookup_name = mcp_server_name or client_id
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(lookup_name)
|
||||
if mcp_server is None:
|
||||
raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
|
||||
if grant_type != "authorization_code":
|
||||
raise HTTPException(status_code=400, detail="Unsupported grant_type")
|
||||
|
||||
if mcp_server.token_url is None:
|
||||
raise HTTPException(status_code=400, detail="MCP server token url is not set")
|
||||
|
||||
# Get the correct base URL considering X-Forwarded-* headers
|
||||
proxy_base_url = get_request_base_url(request)
|
||||
|
||||
# Build token request data
|
||||
token_data = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": client_id if client_id else mcp_server.client_id,
|
||||
"client_secret": client_secret if client_secret else mcp_server.client_secret,
|
||||
"code": code,
|
||||
"redirect_uri": f"{proxy_base_url}/callback",
|
||||
}
|
||||
|
||||
# Forward PKCE code_verifier if present
|
||||
if code_verifier:
|
||||
token_data["code_verifier"] = code_verifier
|
||||
|
||||
# Exchange code for real OAuth token
|
||||
async_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
|
||||
response = await async_client.post(
|
||||
mcp_server.token_url,
|
||||
headers={"Accept": "application/json"},
|
||||
data=token_data,
|
||||
return await exchange_token_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
grant_type=grant_type,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
token_response = response.json()
|
||||
access_token = token_response["access_token"]
|
||||
|
||||
# Return to client in expected OAuth 2 format
|
||||
# Only include fields that have values
|
||||
result = {
|
||||
"access_token": access_token,
|
||||
"token_type": token_response.get("token_type", "Bearer"),
|
||||
"expires_in": token_response.get("expires_in", 3600),
|
||||
}
|
||||
|
||||
# Add optional fields only if they exist
|
||||
if "refresh_token" in token_response and token_response["refresh_token"]:
|
||||
result["refresh_token"] = token_response["refresh_token"]
|
||||
if "scope" in token_response and token_response["scope"]:
|
||||
result["scope"] = token_response["scope"]
|
||||
|
||||
return JSONResponse(result)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback(code: str, state: str):
|
||||
@@ -391,44 +474,12 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(mcp_server_name)
|
||||
if mcp_server is None:
|
||||
return dummy_return
|
||||
|
||||
if mcp_server.client_id and mcp_server.client_secret:
|
||||
return {
|
||||
"client_id": mcp_server.client_id,
|
||||
"client_secret": mcp_server.client_secret,
|
||||
"redirect_uris": [f"{request_base_url}/callback"],
|
||||
}
|
||||
|
||||
if mcp_server.authorization_url is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="MCP server authorization url is not set"
|
||||
)
|
||||
|
||||
if mcp_server.registration_url is None:
|
||||
return dummy_return
|
||||
|
||||
register_data = {
|
||||
"client_name": data.get("client_name", ""),
|
||||
"redirect_uris": [f"{request_base_url}/callback"],
|
||||
"grant_types": data.get("grant_types", []),
|
||||
"response_types": data.get("response_types", []),
|
||||
"token_endpoint_auth_method": data.get("token_endpoint_auth_method", ""),
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.Oauth2Register
|
||||
return await register_client_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
client_name=data.get("client_name", ""),
|
||||
grant_types=data.get("grant_types", []),
|
||||
response_types=data.get("response_types", []),
|
||||
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
|
||||
fallback_client_id=mcp_server_name,
|
||||
)
|
||||
response = await async_client.post(
|
||||
mcp_server.registration_url,
|
||||
headers=headers,
|
||||
json=register_data,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
token_response = response.json()
|
||||
|
||||
return JSONResponse(token_response)
|
||||
|
||||
@@ -395,12 +395,12 @@ class MCPServerManager:
|
||||
)
|
||||
|
||||
# Update tool name to server name mapping (for both prefixed and base names)
|
||||
self.tool_name_to_mcp_server_name_mapping[base_tool_name] = (
|
||||
server_prefix
|
||||
)
|
||||
self.tool_name_to_mcp_server_name_mapping[prefixed_tool_name] = (
|
||||
server_prefix
|
||||
)
|
||||
self.tool_name_to_mcp_server_name_mapping[
|
||||
base_tool_name
|
||||
] = server_prefix
|
||||
self.tool_name_to_mcp_server_name_mapping[
|
||||
prefixed_tool_name
|
||||
] = server_prefix
|
||||
|
||||
registered_count += 1
|
||||
verbose_logger.debug(
|
||||
@@ -432,73 +432,127 @@ class MCPServerManager:
|
||||
f"Server ID {mcp_server.server_id} not found in registry"
|
||||
)
|
||||
|
||||
def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
|
||||
async def build_mcp_server_from_table(
    self,
    mcp_server: LiteLLM_MCPServerTable,
    *,
    credentials_are_encrypted: bool = True,
) -> MCPServer:
    """
    Convert a ``LiteLLM_MCPServerTable`` record into an ``MCPServer`` object.

    Credentials (``auth_value``, ``client_id``, ``client_secret``, ``scopes``)
    are read from the record's ``credentials`` field. When
    ``credentials_are_encrypted`` is true each secret is passed through
    ``decrypt_value_helper``; otherwise the stored value is used as-is.
    For OAuth2 servers with a URL, the authorization/token/registration URLs
    (and fallback scopes) come from ``self._descovery_metadata``.

    Args:
        mcp_server: The table record to convert.
        credentials_are_encrypted: Whether stored credential values need decrypting.

    Returns:
        The constructed ``MCPServer``.
    """

    def _resolve_credential(stored_value: Optional[str], key: str) -> Optional[str]:
        # Empty/missing values resolve to None; plain values pass through untouched.
        if not stored_value:
            return None
        if not credentials_are_encrypted:
            return stored_value
        return decrypt_value_helper(
            value=stored_value,
            key=key,
            exception_type="debug",
            return_original_value=True,
        )

    raw_mcp_info: MCPInfo = mcp_server.mcp_info or {}
    # These fields may not exist on the record, hence getattr with a default.
    env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None))
    static_headers_dict = _deserialize_json_dict(
        getattr(mcp_server, "static_headers", None)
    )
    credentials_dict = _deserialize_json_dict(
        getattr(mcp_server, "credentials", None)
    )

    stored_auth_value: Optional[str] = None
    stored_client_id: Optional[str] = None
    stored_client_secret: Optional[str] = None
    if credentials_dict:
        stored_auth_value = credentials_dict.get("auth_value")
        stored_client_id = credentials_dict.get("client_id")
        stored_client_secret = credentials_dict.get("client_secret")

    auth_value = _resolve_credential(stored_auth_value, "auth_value")
    client_id_value = _resolve_credential(stored_client_id, "client_id")
    client_secret_value = _resolve_credential(stored_client_secret, "client_secret")

    scopes: Optional[List[str]] = None
    stored_scopes = credentials_dict.get("scopes") if credentials_dict else None
    if stored_scopes is not None:
        scopes = self._extract_scopes(stored_scopes)

    # Alias wins over server name, which wins over the raw id.
    name_for_prefix = (
        mcp_server.alias or mcp_server.server_name or mcp_server.server_id
    )

    # Keep custom fields from the record; only fill in missing core fields.
    mcp_info: MCPInfo = raw_mcp_info.copy()
    if "server_name" not in mcp_info:
        mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id
    if "description" not in mcp_info and mcp_server.description:
        mcp_info["description"] = mcp_server.description

    auth_type = cast(MCPAuthType, mcp_server.auth_type)
    mcp_oauth_metadata = None
    if mcp_server.url and auth_type == MCPAuth.oauth2:
        mcp_oauth_metadata = await self._descovery_metadata(
            server_url=mcp_server.url,
        )

    # Explicitly stored scopes take priority over discovered ones.
    if scopes:
        resolved_scopes = scopes
    elif mcp_oauth_metadata:
        resolved_scopes = mcp_oauth_metadata.scopes
    else:
        resolved_scopes = None

    return MCPServer(
        server_id=mcp_server.server_id,
        name=name_for_prefix,
        alias=getattr(mcp_server, "alias", None),
        server_name=getattr(mcp_server, "server_name", None),
        url=mcp_server.url,
        transport=cast(MCPTransportType, mcp_server.transport),
        auth_type=auth_type,
        authentication_token=auth_value,
        mcp_info=mcp_info,
        extra_headers=getattr(mcp_server, "extra_headers", None),
        static_headers=static_headers_dict,
        # OAuth-specific fields
        client_id=client_id_value or getattr(mcp_server, "client_id", None),
        client_secret=client_secret_value
        or getattr(mcp_server, "client_secret", None),
        scopes=resolved_scopes,
        authorization_url=getattr(mcp_oauth_metadata, "authorization_url", None),
        token_url=getattr(mcp_oauth_metadata, "token_url", None),
        registration_url=getattr(mcp_oauth_metadata, "registration_url", None),
        # Stdio-specific fields
        command=getattr(mcp_server, "command", None),
        args=getattr(mcp_server, "args", None) or [],
        env=env_dict,
        access_groups=getattr(mcp_server, "mcp_access_groups", None),
        allowed_tools=getattr(mcp_server, "allowed_tools", None),
        disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
    )
|
||||
|
||||
async def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
|
||||
try:
|
||||
if mcp_server.server_id not in self.get_registry():
|
||||
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
|
||||
# Use helper to deserialize dictionary
|
||||
# Safely access env field which may not exist on Prisma model objects
|
||||
env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None))
|
||||
static_headers_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "static_headers", None)
|
||||
)
|
||||
credentials_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "credentials", None)
|
||||
)
|
||||
|
||||
encrypted_auth_value: Optional[str] = None
|
||||
if credentials_dict:
|
||||
encrypted_auth_value = credentials_dict.get("auth_value")
|
||||
|
||||
auth_value: Optional[str] = None
|
||||
if encrypted_auth_value:
|
||||
auth_value = decrypt_value_helper(
|
||||
value=encrypted_auth_value,
|
||||
key="auth_value",
|
||||
)
|
||||
# Use alias for name if present, else server_name
|
||||
name_for_prefix = (
|
||||
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
|
||||
)
|
||||
# Preserve all custom fields from database while setting defaults for core fields
|
||||
mcp_info: MCPInfo = _mcp_info.copy()
|
||||
# Set default values for core fields if not present
|
||||
if "server_name" not in mcp_info:
|
||||
mcp_info["server_name"] = (
|
||||
mcp_server.server_name or mcp_server.server_id
|
||||
)
|
||||
if "description" not in mcp_info and mcp_server.description:
|
||||
mcp_info["description"] = mcp_server.description
|
||||
|
||||
new_server = MCPServer(
|
||||
server_id=mcp_server.server_id,
|
||||
name=name_for_prefix,
|
||||
alias=getattr(mcp_server, "alias", None),
|
||||
server_name=getattr(mcp_server, "server_name", None),
|
||||
url=mcp_server.url,
|
||||
transport=cast(MCPTransportType, mcp_server.transport),
|
||||
auth_type=cast(MCPAuthType, mcp_server.auth_type),
|
||||
authentication_token=auth_value,
|
||||
mcp_info=mcp_info,
|
||||
extra_headers=getattr(mcp_server, "extra_headers", None),
|
||||
static_headers=static_headers_dict,
|
||||
# oauth specific fields
|
||||
client_id=getattr(mcp_server, "client_id", None),
|
||||
client_secret=getattr(mcp_server, "client_secret", None),
|
||||
scopes=getattr(mcp_server, "scopes", None),
|
||||
authorization_url=getattr(mcp_server, "authorization_url", None),
|
||||
token_url=getattr(mcp_server, "token_url", None),
|
||||
registration_url=getattr(mcp_server, "registration_url", None),
|
||||
# Stdio-specific fields
|
||||
command=getattr(mcp_server, "command", None),
|
||||
args=getattr(mcp_server, "args", None) or [],
|
||||
env=env_dict,
|
||||
access_groups=getattr(mcp_server, "mcp_access_groups", None),
|
||||
allowed_tools=getattr(mcp_server, "allowed_tools", None),
|
||||
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
|
||||
)
|
||||
if mcp_server.server_id not in self.registry:
|
||||
new_server = await self.build_mcp_server_from_table(mcp_server)
|
||||
self.registry[mcp_server.server_id] = new_server
|
||||
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")
|
||||
verbose_logger.debug(f"Added MCP Server: {new_server.name}")
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Failed to add MCP server: {str(e)}")
|
||||
|
||||
@@ -293,7 +293,11 @@ if MCP_AVAILABLE:
|
||||
NewMCPServerRequest,
|
||||
)
|
||||
|
||||
async def _execute_with_mcp_client(request: NewMCPServerRequest, operation):
|
||||
async def _execute_with_mcp_client(
|
||||
request: NewMCPServerRequest,
|
||||
operation,
|
||||
oauth2_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""
|
||||
Common helper to create MCP client, execute operation, and ensure proper cleanup.
|
||||
|
||||
@@ -315,6 +319,7 @@ if MCP_AVAILABLE:
|
||||
mcp_info=request.mcp_info,
|
||||
),
|
||||
mcp_auth_header=None,
|
||||
extra_headers=oauth2_headers,
|
||||
)
|
||||
|
||||
return await operation(client)
|
||||
@@ -342,12 +347,19 @@ if MCP_AVAILABLE:
|
||||
|
||||
@router.post("/test/tools/list")
|
||||
async def test_tools_list(
|
||||
request: NewMCPServerRequest,
|
||||
request: Request,
|
||||
new_mcp_server_request: NewMCPServerRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Preview tools available from MCP server before adding it
|
||||
"""
|
||||
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
||||
MCPRequestHandler,
|
||||
)
|
||||
|
||||
headers = request.headers
|
||||
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers)
|
||||
|
||||
async def _list_tools_operation(client):
|
||||
async def _list_tools_session_operation(session):
|
||||
@@ -366,4 +378,6 @@ if MCP_AVAILABLE:
|
||||
"message": "Successfully retrieved tools",
|
||||
}
|
||||
|
||||
return await _execute_with_mcp_client(request, _list_tools_operation)
|
||||
return await _execute_with_mcp_client(
|
||||
new_mcp_server_request, _list_tools_operation, oauth2_headers
|
||||
)
|
||||
|
||||
@@ -14,13 +14,24 @@ Endpoints here:
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from datetime import datetime
|
||||
from typing import Iterable, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Response, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
Form,
|
||||
Header,
|
||||
HTTPException,
|
||||
Request,
|
||||
Response,
|
||||
status,
|
||||
)
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
import litellm
|
||||
from litellm._uuid import uuid
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
|
||||
from litellm.proxy._experimental.mcp_server.utils import (
|
||||
@@ -29,6 +40,7 @@ from litellm.proxy._experimental.mcp_server.utils import (
|
||||
|
||||
router = APIRouter(prefix="/v1/mcp", tags=["mcp"])
|
||||
MCP_AVAILABLE: bool = True
|
||||
TEMPORARY_MCP_SERVER_TTL_SECONDS = 300
|
||||
try:
|
||||
importlib.import_module("mcp")
|
||||
except ImportError as e:
|
||||
@@ -43,9 +55,15 @@ if MCP_AVAILABLE:
|
||||
get_mcp_server,
|
||||
update_mcp_server,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
|
||||
authorize_with_server,
|
||||
exchange_token_with_server,
|
||||
register_client_with_server,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
LitellmUserRoles,
|
||||
@@ -58,6 +76,47 @@ if MCP_AVAILABLE:
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
@dataclass
class _TemporaryMCPServerEntry:
    """In-memory cache entry pairing a temporary MCP server with its expiry time."""

    # The cached server object.
    server: MCPServer
    # Naive timestamp; entries with expires_at <= datetime.utcnow() are pruned.
    expires_at: datetime
|
||||
|
||||
# Process-local cache of temporary MCP servers, keyed by server_id.
_temporary_mcp_servers: Dict[str, _TemporaryMCPServerEntry] = {}
|
||||
|
||||
def _prune_expired_temporary_mcp_servers() -> None:
    """Remove every cached temporary server whose expiry time has been reached."""
    if not _temporary_mcp_servers:
        return

    cutoff = datetime.utcnow()
    # Iterate over a snapshot so entries can be removed while scanning.
    for cached_id, cached_entry in list(_temporary_mcp_servers.items()):
        if cached_entry.expires_at <= cutoff:
            _temporary_mcp_servers.pop(cached_id, None)
|
||||
|
||||
def _cache_temporary_mcp_server(server: MCPServer, ttl_seconds: int) -> MCPServer:
    """
    Store ``server`` in the temporary cache under its ``server_id``.

    The TTL is clamped to at least one second. Expired entries are pruned
    first, and an existing entry with the same id is overwritten.

    Returns:
        The same ``server`` object that was passed in.
    """
    effective_ttl = max(1, ttl_seconds)
    _prune_expired_temporary_mcp_servers()

    entry = _TemporaryMCPServerEntry(
        server=server,
        expires_at=datetime.utcnow() + timedelta(seconds=effective_ttl),
    )
    _temporary_mcp_servers[server.server_id] = entry
    return server
|
||||
|
||||
def get_cached_temporary_mcp_server(
    server_id: str,
) -> Optional[MCPServer]:
    """
    Look up a temporary MCP server by id.

    Expired entries are pruned before the lookup.

    Returns:
        The cached ``MCPServer``, or ``None`` when no live entry exists.
    """
    _prune_expired_temporary_mcp_servers()
    cached_entry = _temporary_mcp_servers.get(server_id)
    return None if cached_entry is None else cached_entry.server
|
||||
|
||||
def _redact_mcp_credentials(
|
||||
mcp_server: LiteLLM_MCPServerTable,
|
||||
@@ -79,6 +138,75 @@ if MCP_AVAILABLE:
|
||||
) -> List[LiteLLM_MCPServerTable]:
|
||||
return [_redact_mcp_credentials(server) for server in mcp_servers]
|
||||
|
||||
def _inherit_credentials_from_existing_server(
    payload: NewMCPServerRequest,
) -> NewMCPServerRequest:
    """
    Fill in missing credentials on ``payload`` from an already-registered server.

    The payload is returned unchanged when it has no ``server_id``, already
    carries credentials, references a server the manager does not know, or
    when that server has no credentials to copy. Otherwise a copy of the
    payload is returned with ``credentials`` set to the registered server's
    token, client id, client secret and scopes (only the non-empty ones).
    """
    if not payload.server_id or payload.credentials:
        return payload

    registered_server = global_mcp_server_manager.get_mcp_server_by_id(
        payload.server_id
    )
    if registered_server is None:
        return payload

    inherited_credentials: MCPCredentials = {}
    token = registered_server.authentication_token
    if token:
        inherited_credentials["auth_value"] = token
    registered_client_id = registered_server.client_id
    if registered_client_id:
        inherited_credentials["client_id"] = registered_client_id
    registered_client_secret = registered_server.client_secret
    if registered_client_secret:
        inherited_credentials["client_secret"] = registered_client_secret
    registered_scopes = registered_server.scopes
    if registered_scopes:
        inherited_credentials["scopes"] = registered_scopes

    if not inherited_credentials:
        return payload

    # Preferred path: copy the model with the credentials swapped in.
    try:
        return payload.model_copy(update={"credentials": inherited_credentials})
    except AttributeError:
        # model_copy is unavailable; rebuild the request from a dict below.
        pass

    rebuilt_fields: Dict[str, Any]
    try:
        rebuilt_fields = payload.model_dump()  # type: ignore[attr-defined]
    except AttributeError:
        rebuilt_fields = payload.dict()  # type: ignore[attr-defined]
    rebuilt_fields["credentials"] = inherited_credentials
    return NewMCPServerRequest(**rebuilt_fields)
|
||||
|
||||
def _build_temporary_mcp_server_record(
    payload: NewMCPServerRequest,
    created_by: Optional[str],
) -> LiteLLM_MCPServerTable:
    """
    Build a ``LiteLLM_MCPServerTable`` record from ``payload``.

    A random UUID is used when the payload has no ``server_id``; the server
    name falls back to the alias and then to the id. ``created_by`` is
    recorded as both creator and updater, and both timestamps are set to the
    current ``datetime.utcnow()``.
    """
    timestamp = datetime.utcnow()
    resolved_id = payload.server_id or str(uuid.uuid4())
    resolved_name = payload.server_name or payload.alias or resolved_id

    record_fields: Dict[str, Any] = {
        # Identity
        "server_id": resolved_id,
        "server_name": resolved_name,
        "alias": payload.alias,
        "description": payload.description,
        # Connection and auth
        "url": payload.url,
        "transport": payload.transport,
        "auth_type": payload.auth_type,
        "credentials": payload.credentials,
        # Audit fields
        "created_at": timestamp,
        "updated_at": timestamp,
        "created_by": created_by,
        "updated_by": created_by,
        # Access control and tool filtering
        "teams": [],
        "mcp_access_groups": payload.mcp_access_groups,
        "allowed_tools": payload.allowed_tools or [],
        "extra_headers": payload.extra_headers or [],
        "mcp_info": payload.mcp_info,
        "static_headers": payload.static_headers,
        # Stdio-specific fields
        "command": payload.command,
        "args": payload.args,
        "env": payload.env,
    }
    return LiteLLM_MCPServerTable(**record_fields)
|
||||
|
||||
def get_prisma_client_or_throw(message: str):
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
@@ -376,7 +504,7 @@ if MCP_AVAILABLE:
|
||||
exists = does_mcp_server_exist(mcp_server_records, server_id)
|
||||
|
||||
if exists:
|
||||
global_mcp_server_manager.add_update_server(mcp_server)
|
||||
await global_mcp_server_manager.add_update_server(mcp_server)
|
||||
return _redact_mcp_credentials(mcp_server)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@@ -450,7 +578,7 @@ if MCP_AVAILABLE:
|
||||
payload,
|
||||
touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
|
||||
)
|
||||
global_mcp_server_manager.add_update_server(new_mcp_server)
|
||||
await global_mcp_server_manager.add_update_server(new_mcp_server)
|
||||
|
||||
# Ensure registry is up to date by reloading from database
|
||||
await global_mcp_server_manager.reload_servers_from_database()
|
||||
@@ -462,6 +590,151 @@ if MCP_AVAILABLE:
|
||||
)
|
||||
return _redact_mcp_credentials(new_mcp_server)
|
||||
|
||||
@router.post(
    "/server/oauth/session",
    description="Temporarily cache an MCP server in memory without writing to the database",
    dependencies=[Depends(user_api_key_auth)],
    status_code=status.HTTP_200_OK,
)
@management_endpoint_wrapper
async def add_session_mcp_server(
    payload: NewMCPServerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Cache MCP server info in memory for a short duration (~5 minutes).

    This endpoint does not write to the database. If the same server_id is provided
    again while the cache entry is active, it will refresh the cached data + TTL.

    Args:
        payload: Definition of the MCP server to cache.
        user_api_key_dict: Auth context of the caller; must carry the PROXY_ADMIN role.
        litellm_changed_by: Optional audit header; not read by this function body.

    Returns:
        The temporary server record after passing through ``_redact_mcp_credentials``.

    Raises:
        HTTPException: 403 when the caller is not a PROXY_ADMIN, 500 when building
            or caching the temporary server fails.
    """

    # Validate and normalize payload fields (alias/server name rules)
    validate_and_normalize_mcp_server_payload(payload)

    # Restrict to proxy admins similar to the persistent create endpoint
    if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail={
                "error": "User does not have permission to create temporary mcp servers. You can only create temporary mcp servers if you are a PROXY_ADMIN."
            },
        )

    created_by = user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME
    payload_with_credentials = _inherit_credentials_from_existing_server(payload)
    temp_record = _build_temporary_mcp_server_record(
        payload_with_credentials,
        created_by,
    )

    try:
        temporary_server = (
            await global_mcp_server_manager.build_mcp_server_from_table(
                temp_record,
                credentials_are_encrypted=False,
            )
        )
        _cache_temporary_mcp_server(
            temporary_server,
            ttl_seconds=TEMPORARY_MCP_SERVER_TTL_SECONDS,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error caching temporary mcp server: {str(e)}"
        )
        # Chain the original failure so the root cause stays attached to the
        # HTTP error instead of appearing as an implicit "during handling" context.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": f"Error caching temporary mcp server: {str(e)}"},
        ) from e

    return _redact_mcp_credentials(temp_record)
|
||||
|
||||
def _get_cached_temporary_mcp_server_or_404(server_id: str) -> MCPServer:
    """
    Return the temporary MCP server cached under ``server_id``.

    Raises:
        HTTPException: 404 when ``get_cached_temporary_mcp_server`` returns None.
    """
    cached_server = get_cached_temporary_mcp_server(server_id)
    if cached_server is not None:
        return cached_server

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail={"error": f"Temporary MCP server {server_id} not found"},
    )
|
||||
|
||||
@router.get(
    "/server/oauth/{server_id}/authorize",
    include_in_schema=False,
)
async def mcp_authorize(
    request: Request,
    server_id: str,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Run the authorize step against a temporarily cached MCP server.

    Looks up the server cached under ``server_id`` (404 if absent) and forwards
    every query parameter unchanged to ``authorize_with_server``.
    """
    forwarded = {
        "request": request,
        # Resolved first so a missing cache entry fails before anything else runs.
        "mcp_server": _get_cached_temporary_mcp_server_or_404(server_id),
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "state": state,
        "code_challenge": code_challenge,
        "code_challenge_method": code_challenge_method,
        "response_type": response_type,
        "scope": scope,
    }
    return await authorize_with_server(**forwarded)
|
||||
|
||||
@router.post(
    "/server/oauth/{server_id}/token",
    include_in_schema=False,
)
async def mcp_token(
    request: Request,
    server_id: str,
    grant_type: str = Form(...),
    code: Optional[str] = Form(None),
    redirect_uri: Optional[str] = Form(None),
    client_id: str = Form(...),
    client_secret: Optional[str] = Form(None),
    code_verifier: Optional[str] = Form(None),
):
    """
    Run the token exchange against a temporarily cached MCP server.

    Looks up the server cached under ``server_id`` (404 if absent) and forwards
    the submitted form fields unchanged to ``exchange_token_with_server``.
    """
    forwarded = {
        "request": request,
        # Resolved first so a missing cache entry fails before anything else runs.
        "mcp_server": _get_cached_temporary_mcp_server_or_404(server_id),
        "grant_type": grant_type,
        "code": code,
        "redirect_uri": redirect_uri,
        "client_id": client_id,
        "client_secret": client_secret,
        "code_verifier": code_verifier,
    }
    return await exchange_token_with_server(**forwarded)
|
||||
|
||||
@router.post(
    "/server/oauth/{server_id}/register",
    include_in_schema=False,
)
async def mcp_register(request: Request, server_id: str):
    """
    Run client registration against a temporarily cached MCP server.

    Looks up the server cached under ``server_id`` (404 if absent), reads the
    request body, and passes the registration fields found there to
    ``register_client_with_server``. Missing fields fall back to an empty
    string or empty list; ``server_id`` is passed as ``fallback_client_id``.
    """
    mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
    raw_body = await _read_request_body(request=request)
    data: dict = {**raw_body}

    client_name = data.get("client_name", "")
    grant_types = data.get("grant_types", [])
    response_types = data.get("response_types", [])
    auth_method = data.get("token_endpoint_auth_method", "")

    return await register_client_with_server(
        request=request,
        mcp_server=mcp_server,
        client_name=client_name,
        grant_types=grant_types,
        response_types=response_types,
        token_endpoint_auth_method=auth_method,
        fallback_client_id=server_id,
    )
|
||||
|
||||
@router.delete(
|
||||
"/server/{server_id}",
|
||||
description="Allows deleting mcp serves in the db",
|
||||
@@ -586,7 +859,7 @@ if MCP_AVAILABLE:
|
||||
"error": f"MCP Server not found, passed server_id={payload.server_id}"
|
||||
},
|
||||
)
|
||||
global_mcp_server_manager.add_update_server(mcp_server_record_updated)
|
||||
await global_mcp_server_manager.add_update_server(mcp_server_record_updated)
|
||||
|
||||
# Ensure registry is up to date by reloading from database
|
||||
await global_mcp_server_manager.reload_servers_from_database()
|
||||
|
||||
@@ -1048,7 +1048,7 @@ async def test_mcp_server_manager_config_integration_with_database():
|
||||
)
|
||||
|
||||
# Test the add_update_server method (this tests our fix)
|
||||
test_manager.add_update_server(db_server)
|
||||
await test_manager.add_update_server(db_server)
|
||||
|
||||
# Verify the server was added with correct access_groups
|
||||
registry = test_manager.get_registry()
|
||||
@@ -1342,7 +1342,8 @@ async def test_mcp_server_manager_server_id_tool_prefixing():
|
||||
)
|
||||
|
||||
|
||||
def test_add_update_server_with_alias():
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_update_server_with_alias():
|
||||
"""
|
||||
Test that add_update_server correctly handles servers with alias.
|
||||
"""
|
||||
@@ -1371,7 +1372,7 @@ def test_add_update_server_with_alias():
|
||||
mock_mcp_server.token_url = None
|
||||
|
||||
# Add server to manager
|
||||
test_manager.add_update_server(mock_mcp_server)
|
||||
await test_manager.add_update_server(mock_mcp_server)
|
||||
|
||||
# Verify server was added with correct name (should use alias)
|
||||
assert "test-server-123" in test_manager.registry
|
||||
@@ -1381,7 +1382,8 @@ def test_add_update_server_with_alias():
|
||||
assert added_server.server_name == "Test Server"
|
||||
|
||||
|
||||
def test_add_update_server_without_alias():
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_update_server_without_alias():
|
||||
"""
|
||||
Test that add_update_server correctly handles servers without alias.
|
||||
"""
|
||||
@@ -1410,7 +1412,7 @@ def test_add_update_server_without_alias():
|
||||
mock_mcp_server.token_url = None
|
||||
|
||||
# Add server to manager
|
||||
test_manager.add_update_server(mock_mcp_server)
|
||||
await test_manager.add_update_server(mock_mcp_server)
|
||||
|
||||
# Verify server was added with correct name (should use server_name)
|
||||
assert "test-server-123" in test_manager.registry
|
||||
@@ -1420,7 +1422,8 @@ def test_add_update_server_without_alias():
|
||||
assert added_server.server_name == "Test Server"
|
||||
|
||||
|
||||
def test_add_update_server_fallback_to_server_id():
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_update_server_fallback_to_server_id():
|
||||
"""
|
||||
Test that add_update_server falls back to server_id when neither alias nor server_name are available.
|
||||
"""
|
||||
@@ -1449,7 +1452,7 @@ def test_add_update_server_fallback_to_server_id():
|
||||
mock_mcp_server.token_url = None
|
||||
|
||||
# Add server to manager
|
||||
test_manager.add_update_server(mock_mcp_server)
|
||||
await test_manager.add_update_server(mock_mcp_server)
|
||||
|
||||
# Verify server was added with correct name (should use server_id)
|
||||
assert "test-server-123" in test_manager.registry
|
||||
|
||||
@@ -310,8 +310,8 @@ async def test_register_client_returns_existing_server_credentials():
|
||||
global_mcp_server_manager.registry.clear()
|
||||
|
||||
assert result == {
|
||||
"client_id": "existing-client",
|
||||
"client_secret": "existing-secret",
|
||||
"client_id": "stored_server",
|
||||
"client_secret": "dummy",
|
||||
"redirect_uris": ["https://proxy.litellm.example/callback"],
|
||||
}
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TestMCPCustomFields:
|
||||
assert mcp_info["priority"] == 10
|
||||
assert mcp_info["tags"] == ["production", "api"]
|
||||
|
||||
def test_custom_fields_preserved_from_database(self):
|
||||
async def test_custom_fields_preserved_from_database(self):
|
||||
"""Test that custom fields in mcp_info are preserved when adding from database."""
|
||||
manager = MCPServerManager()
|
||||
|
||||
@@ -92,7 +92,7 @@ class TestMCPCustomFields:
|
||||
mock_server.mcp_access_groups = None
|
||||
|
||||
# Add server to manager
|
||||
manager.add_update_server(mock_server)
|
||||
await manager.add_update_server(mock_server)
|
||||
|
||||
# Get the added server
|
||||
server = manager.get_mcp_server_by_id("test-server-id")
|
||||
|
||||
@@ -58,7 +58,7 @@ class TestMCPServerManager:
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
|
||||
manager.add_update_server(stdio_server)
|
||||
await manager.add_update_server(stdio_server)
|
||||
|
||||
# Verify server was added
|
||||
assert "stdio-server-1" in manager.registry
|
||||
@@ -1265,7 +1265,7 @@ class TestMCPServerManager:
|
||||
"env": {},
|
||||
},
|
||||
)
|
||||
manager.add_update_server(server)
|
||||
await manager.add_update_server(server)
|
||||
assert server.server_id in manager.get_registry()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@@ -19,6 +19,7 @@ from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
LitellmUserRoles,
|
||||
MCPTransport,
|
||||
NewMCPServerRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
@@ -854,3 +855,316 @@ class TestMCPHealthCheckEndpoints:
|
||||
assert server.last_health_check is not None
|
||||
assert server.health_check_error is None
|
||||
assert server.credentials is None
|
||||
|
||||
|
||||
class TestTemporaryMCPSessionEndpoints:
    """Tests for the temporary (in-memory) MCP server session helpers and endpoints."""

    def test_inherit_credentials_from_existing_server(self):
        """Credentials are taken from the server the manager returns for the payload's id."""
        request_payload = NewMCPServerRequest(
            server_id="server-123",
            alias="Temp Server",
            url="https://temp.example.com",
            transport=MCPTransport.http,
        )
        stored_server = MagicMock(
            authentication_token="token-abc",
            client_id="client-123",
            client_secret="secret-xyz",
            scopes=["scope:a", "scope:b"],
        )

        manager = MagicMock()
        manager.get_mcp_server_by_id.return_value = stored_server

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
            manager,
        ):
            from litellm.proxy.management_endpoints.mcp_management_endpoints import (
                _inherit_credentials_from_existing_server,
            )

            result_payload = _inherit_credentials_from_existing_server(
                request_payload
            )

        expected_credentials = {
            "auth_value": "token-abc",
            "client_id": "client-123",
            "client_secret": "secret-xyz",
            "scopes": ["scope:a", "scope:b"],
        }
        assert result_payload.credentials == expected_credentials
        manager.get_mcp_server_by_id.assert_called_once_with("server-123")

    def test_cache_temporary_mcp_server_stores_entry_with_ttl(self):
        """Caching stores an entry keyed by server id whose expiry lies in the future."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            _cache_temporary_mcp_server,
        )

        record = generate_mock_mcp_server_config_record(server_id="temp-cache")
        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._temporary_mcp_servers",
            {},
        ) as patched_cache:
            returned = _cache_temporary_mcp_server(record, ttl_seconds=2)

            assert returned is record
            assert "temp-cache" in patched_cache
            entry = patched_cache["temp-cache"]
            assert entry.server is record
            assert entry.expires_at > datetime.utcnow()

    def test_get_cached_temporary_mcp_server_prunes_expired_entries(self):
        """Looking up an expired entry yields None and removes it from the cache."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            _TemporaryMCPServerEntry,
            get_cached_temporary_mcp_server,
        )

        record = generate_mock_mcp_server_config_record(server_id="expired")
        already_expired = datetime.utcnow() - timedelta(seconds=30)
        store = {
            "expired": _TemporaryMCPServerEntry(
                server=record,
                expires_at=already_expired,
            )
        }

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._temporary_mcp_servers",
            store,
        ):
            lookup = get_cached_temporary_mcp_server("expired")

        assert lookup is None
        assert "expired" not in store

    def test_get_cached_temporary_mcp_server_or_404(self):
        """The helper returns a cached server and raises a 404 when none is cached."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            _get_cached_temporary_mcp_server_or_404,
        )

        record = generate_mock_mcp_server_config_record(server_id="cached")

        # Hit: the cached server is returned untouched.
        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.get_cached_temporary_mcp_server",
            return_value=record,
        ) as lookup_mock:
            found = _get_cached_temporary_mcp_server_or_404("cached")

        assert found is record
        lookup_mock.assert_called_once_with("cached")

        # Miss: the helper raises an HTTPException.
        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.get_cached_temporary_mcp_server",
            return_value=None,
        ):
            from fastapi import HTTPException

            with pytest.raises(HTTPException) as raised:
                _get_cached_temporary_mcp_server_or_404("missing")

        assert raised.value.status_code == 404

    @pytest.mark.asyncio
    async def test_add_session_mcp_server_caches_and_redacts_credentials(self):
        """The endpoint builds and caches the server and returns a record without credentials."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            TEMPORARY_MCP_SERVER_TTL_SECONDS,
            add_session_mcp_server,
        )

        request_payload = NewMCPServerRequest(
            server_id="temp-server",
            alias="Temporary",
            url="https://temp.example.com",
            transport=MCPTransport.http,
        )
        admin_auth = generate_mock_user_api_key_auth(
            user_role=LitellmUserRoles.PROXY_ADMIN,
            user_id="admin-user",
        )

        existing = MagicMock()
        existing.authentication_token = "token-abc"
        existing.client_id = "client-id"
        existing.client_secret = "client-secret"
        existing.scopes = ["scope1"]

        built = generate_mock_mcp_server_config_record(server_id="temp-server")

        manager = MagicMock()
        manager.get_mcp_server_by_id.return_value = existing
        manager.build_mcp_server_from_table = AsyncMock(return_value=built)

        validate_stub = MagicMock()
        cache_stub = MagicMock()

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
            validate_stub,
        ), patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager",
            manager,
        ), patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._cache_temporary_mcp_server",
            cache_stub,
        ):
            response = await add_session_mcp_server(
                payload=request_payload,
                user_api_key_dict=admin_auth,
            )

        validate_stub.assert_called_once_with(request_payload)
        manager.build_mcp_server_from_table.assert_awaited_once()
        cache_stub.assert_called_once_with(
            built, ttl_seconds=TEMPORARY_MCP_SERVER_TTL_SECONDS
        )

        # First positional argument handed to the builder is the temporary record.
        record_passed_to_builder = manager.build_mcp_server_from_table.call_args[0][0]
        expected_credentials = {
            "auth_value": "token-abc",
            "client_id": "client-id",
            "client_secret": "client-secret",
            "scopes": ["scope1"],
        }
        assert record_passed_to_builder.credentials == expected_credentials
        assert response.credentials is None

    @pytest.mark.asyncio
    async def test_add_session_mcp_server_rejects_non_admins(self):
        """A caller without the PROXY_ADMIN role gets a permission error."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            add_session_mcp_server,
        )

        request_payload = NewMCPServerRequest(
            alias="Temporary",
            server_id="temp-server",
            url="https://temp.example.com",
            transport=MCPTransport.http,
        )
        internal_user = generate_mock_user_api_key_auth(
            user_role=LitellmUserRoles.INTERNAL_USER,
        )

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload",
            MagicMock(),
        ):
            with pytest.raises(Exception) as raised:
                await add_session_mcp_server(
                    payload=request_payload,
                    user_api_key_dict=internal_user,
                )

        assert "permission" in str(raised.value)

    @pytest.mark.asyncio
    async def test_mcp_authorize_proxies_to_discoverable_endpoint(self):
        """mcp_authorize resolves the cached server and forwards all parameters."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            mcp_authorize,
        )

        request = MagicMock()
        cached = generate_mock_mcp_server_config_record(server_id="server-1")
        downstream_result = MagicMock()

        # Parameters that must reach authorize_with_server unchanged.
        passthrough = {
            "client_id": "client-id",
            "redirect_uri": "https://example.com/callback",
            "state": "state123",
            "code_challenge": "challenge",
            "code_challenge_method": "S256",
            "response_type": "code",
            "scope": "scope1",
        }

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
            return_value=cached,
        ) as lookup_mock, patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.authorize_with_server",
            AsyncMock(return_value=downstream_result),
        ) as downstream_mock:
            outcome = await mcp_authorize(
                request=request,
                server_id="server-1",
                **passthrough,
            )

        assert outcome is downstream_result
        lookup_mock.assert_called_once_with("server-1")
        downstream_mock.assert_awaited_once_with(
            request=request,
            mcp_server=cached,
            **passthrough,
        )

    @pytest.mark.asyncio
    async def test_mcp_token_proxies_to_exchange_endpoint(self):
        """mcp_token resolves the cached server and forwards all form fields."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            mcp_token,
        )

        request = MagicMock()
        cached = generate_mock_mcp_server_config_record(server_id="server-1")
        downstream_result = {"access_token": "token"}

        # Form fields that must reach exchange_token_with_server unchanged.
        passthrough = {
            "grant_type": "authorization_code",
            "code": "code-123",
            "redirect_uri": "https://example.com/callback",
            "client_id": "client",
            "client_secret": "secret",
            "code_verifier": "verifier",
        }

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
            return_value=cached,
        ) as lookup_mock, patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.exchange_token_with_server",
            AsyncMock(return_value=downstream_result),
        ) as downstream_mock:
            outcome = await mcp_token(
                request=request,
                server_id="server-1",
                **passthrough,
            )

        assert outcome is downstream_result
        lookup_mock.assert_called_once_with("server-1")
        downstream_mock.assert_awaited_once_with(
            request=request,
            mcp_server=cached,
            **passthrough,
        )

    @pytest.mark.asyncio
    async def test_mcp_register_proxies_request_body(self):
        """mcp_register reads the body and forwards its registration fields."""
        from litellm.proxy.management_endpoints.mcp_management_endpoints import (
            mcp_register,
        )

        request = MagicMock()
        cached = generate_mock_mcp_server_config_record(server_id="server-1")
        downstream_result = {"client_id": "generated"}
        body = {
            "client_name": "LiteLLM",
            "grant_types": ["authorization_code"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "client_secret_basic",
        }
        expected_kwargs = {
            "request": request,
            "mcp_server": cached,
            "client_name": "LiteLLM",
            "grant_types": ["authorization_code"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "client_secret_basic",
            "fallback_client_id": "server-1",
        }

        with patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._get_cached_temporary_mcp_server_or_404",
            return_value=cached,
        ) as lookup_mock, patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints._read_request_body",
            AsyncMock(return_value=body),
        ) as body_reader, patch(
            "litellm.proxy.management_endpoints.mcp_management_endpoints.register_client_with_server",
            AsyncMock(return_value=downstream_result),
        ) as downstream_mock:
            outcome = await mcp_register(request=request, server_id="server-1")

        assert outcome is downstream_result
        lookup_mock.assert_called_once_with("server-1")
        body_reader.assert_awaited_once_with(request=request)
        downstream_mock.assert_awaited_once_with(**expected_kwargs)
|
||||
|
||||
Reference in New Issue
Block a user