From d612d71ef427c2e2de01448d69e81979c7d14f93 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 27 Nov 2025 12:06:53 -0800 Subject: [PATCH] [Feat] Add guardrails for pass through endpoints (#17221) * add PassThroughGuardrailsConfig * init JsonPathExtractor * feat PassthroughGuardrailHandler * feat pt guardrails * pt guardrails * add Pass-Through Endpoint Guardrail Translation * add PassThroughEndpointHandler * execute simple guardrail config and dict settings * TestPassthroughGuardrailHandlerNormalizeConfig * add passthrough_guardrails_config on litellm logging obj * add LiteLLMLoggingObj to base trasaltino * cleaner _get_guardrail_settings * update guardrails settings * docs pt guardrail * docs Guardrails on Pass-Through Endpoints * fix typing * fix typing * test_no_fields_set_sends_full_body * fix typing * Potential fix for code scanning alert no. 3834: Clear-text logging of sensitive information Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --------- Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> --- .../docs/proxy/pass_through_guardrails.md | 214 +++++++++++ docs/my-website/sidebars.js | 3 +- litellm/litellm_core_utils/litellm_logging.py | 3 + .../chat/guardrail_translation/handler.py | 2 + .../guardrail_translation/base_translation.py | 5 +- .../rerank/guardrail_translation/handler.py | 4 +- .../chat/guardrail_translation/handler.py | 2 + .../guardrail_translation/handler.py | 4 +- .../guardrail_translation/handler.py | 4 +- .../guardrail_translation/handler.py | 2 + .../speech/guardrail_translation/handler.py | 4 +- .../guardrail_translation/handler.py | 4 +- litellm/llms/pass_through/__init__.py | 12 + .../guardrail_translation/README.md | 41 +++ .../guardrail_translation/__init__.py | 15 + .../guardrail_translation/handler.py | 165 +++++++++ litellm/proxy/_types.py | 27 +- .../unified_guardrail/unified_guardrail.py | 2 + .../jsonpath_extractor.py | 95 +++++ .../pass_through_endpoints.py | 72 +++- .../passthrough_guardrails.py | 333 ++++++++++++++++++ litellm/proxy/proxy_config.yaml | 17 +- ..._passthrough_guardrails_field_targeting.py | 119 +++++++ .../test_passthrough_guardrails.py | 263 ++++++++++++++ 24 files changed, 1378 insertions(+), 34 deletions(-) create mode 100644 docs/my-website/docs/proxy/pass_through_guardrails.md create mode 100644 litellm/llms/pass_through/__init__.py create mode 100644 litellm/llms/pass_through/guardrail_translation/README.md create mode 100644 litellm/llms/pass_through/guardrail_translation/__init__.py create mode 100644 litellm/llms/pass_through/guardrail_translation/handler.py create mode 100644 litellm/proxy/pass_through_endpoints/jsonpath_extractor.py create mode 100644 litellm/proxy/pass_through_endpoints/passthrough_guardrails.py create mode 100644 test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails_field_targeting.py create mode 100644 tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails.py diff --git a/docs/my-website/docs/proxy/pass_through_guardrails.md b/docs/my-website/docs/proxy/pass_through_guardrails.md new file mode 100644 index 0000000000..272285e61c --- /dev/null +++ b/docs/my-website/docs/proxy/pass_through_guardrails.md @@ -0,0 +1,214 @@ +# Guardrails on Pass-Through Endpoints + +## Overview + +| Property | Details | +|----------|---------| +| Description | Enable guardrail execution on LiteLLM pass-through endpoints with opt-in activation and automatic 
inheritance from org/team/key levels | +| Supported Guardrails | All LiteLLM guardrails (Bedrock, Aporia, Lakera, etc.) | +| Default Behavior | Guardrails are **disabled** on pass-through endpoints unless explicitly enabled | + +## Quick Start + +### 1. Define guardrails and pass-through endpoint + +```yaml showLineNumbers title="config.yaml" +guardrails: + - guardrail_name: "pii-guard" + litellm_params: + guardrail: bedrock + mode: pre_call + guardrailIdentifier: "your-guardrail-id" + guardrailVersion: "1" + +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + guardrails: + pii-guard: +``` + +### 2. Start proxy + +```bash +litellm --config config.yaml +``` + +### 3. Test request + +```bash +curl -X POST "http://localhost:4000/v1/rerank" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{ + "model": "rerank-english-v3.0", + "query": "What is the capital of France?", + "documents": ["Paris is the capital of France."] + }' +``` + +--- + +## Opt-In Behavior + +| Configuration | Behavior | +|--------------|----------| +| `guardrails` not set | No guardrails execute (default) | +| `guardrails` set | All org/team/key + pass-through guardrails execute | + +When guardrails are enabled, the system collects and executes: +- Org-level guardrails +- Team-level guardrails +- Key-level guardrails +- Pass-through specific guardrails + +--- + + +## How It Works + +The diagram below shows what happens when a client makes a request to `/special/rerank` - a pass-through endpoint configured with guardrails in your `config.yaml`. + +When guardrails are configured on a pass-through endpoint: +1. **Pre-call guardrails** run on the request before forwarding to the target API +2. If `request_fields` is specified (e.g., `["query"]`), only those fields are sent to the guardrail. Otherwise, the entire request payload is evaluated. +3. The request is forwarded to the target API only if guardrails pass +4. **Post-call guardrails** run on the response from the target API +5. If `response_fields` is specified (e.g., `["results[*].text"]`), only those fields are evaluated. Otherwise, the entire response is checked. + +:::info +If the `guardrails` block is omitted or empty in your pass-through endpoint config, the request skips the guardrail flow entirely and goes directly to the target API. +::: + +```mermaid +sequenceDiagram + participant Client + box rgb(200, 220, 255) LiteLLM Proxy + participant PassThrough as Pass-through Endpoint + participant Guardrails + end + participant Target as Target API (Cohere, etc.) + + Client->>PassThrough: POST /special/rerank + Note over PassThrough,Guardrails: Collect passthrough + org/team/key guardrails + PassThrough->>Guardrails: Run pre_call (request_fields or full payload) + Guardrails-->>PassThrough: ✓ Pass / ✗ Block + PassThrough->>Target: Forward request + Target-->>PassThrough: Response + PassThrough->>Guardrails: Run post_call (response_fields or full payload) + Guardrails-->>PassThrough: ✓ Pass / ✗ Block + PassThrough-->>Client: Return response (or error) +``` + +--- + +## Field-Level Targeting + +Target specific JSON fields instead of the entire request/response payload. 
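The `request_fields` / `response_fields` values use a small JSONPath-like syntax: plain keys (`query`), dotted paths (`foo.bar`), and array wildcards (`documents[*].text`). The snippet below is a simplified, self-contained Python sketch of how such expressions resolve against a request body; it mirrors the extractor this PR adds in `jsonpath_extractor.py` but is illustrative only, not the exact implementation.

```python
from typing import Any, List


def evaluate(data: Any, expr: str) -> Any:
    """Resolve a simple JSONPath-like expression against nested dicts/lists."""
    parts = expr.replace("[*]", ".[*]").split(".")
    current: Any = data
    for i, part in enumerate(parts):
        if part == "[*]":
            # Wildcard: apply the remaining path to every element of the list.
            if not isinstance(current, list):
                return None
            remaining = ".".join(parts[i + 1 :])
            if not remaining:
                return current
            results: List[Any] = []
            for item in current:
                value = evaluate(item, remaining) if isinstance(item, dict) else None
                if isinstance(value, list):
                    results.extend(value)
                elif value:
                    results.append(value)
            return results or None
        if not isinstance(current, dict):
            return None
        current = current.get(part)
    return current


body = {
    "query": "What is the capital of France?",
    "documents": [
        {"text": "Paris is the capital of France."},
        {"text": "Coffee is a brewed drink."},
    ],
}
print(evaluate(body, "query"))              # -> What is the capital of France?
print(evaluate(body, "documents[*].text"))  # -> ['Paris is the capital of France.', 'Coffee is a brewed drink.']
```

When a guardrail entry has `request_fields` or `response_fields` set, only the extracted values are sent to it; otherwise the full payload is serialized and checked. The configuration below attaches field-targeted guardrails to a rerank pass-through endpoint: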
+ +```yaml showLineNumbers title="config.yaml" +guardrails: + - guardrail_name: "pii-detection" + litellm_params: + guardrail: bedrock + mode: pre_call + guardrailIdentifier: "pii-guard-id" + guardrailVersion: "1" + + - guardrail_name: "content-moderation" + litellm_params: + guardrail: bedrock + mode: post_call + guardrailIdentifier: "content-guard-id" + guardrailVersion: "1" + +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + headers: + Authorization: "bearer os.environ/COHERE_API_KEY" + guardrails: + pii-detection: + request_fields: ["query", "documents[*].text"] + content-moderation: + response_fields: ["results[*].text"] +``` + +### Field Options + +| Field | Description | +|-------|-------------| +| `request_fields` | JSONPath expressions for input (pre_call) | +| `response_fields` | JSONPath expressions for output (post_call) | +| Neither specified | Guardrail runs on entire payload | + +### JSONPath Examples + +| Expression | Matches | +|------------|---------| +| `query` | Single field named `query` | +| `documents[*].text` | All `text` fields in `documents` array | +| `messages[*].content` | All `content` fields in `messages` array | + +--- + +## Configuration Examples + +### Single guardrail on entire payload + +```yaml showLineNumbers title="config.yaml" +guardrails: + - guardrail_name: "pii-detection" + litellm_params: + guardrail: bedrock + mode: pre_call + guardrailIdentifier: "your-id" + guardrailVersion: "1" + +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + guardrails: + pii-detection: +``` + +### Multiple guardrails with mixed settings + +```yaml showLineNumbers title="config.yaml" +guardrails: + - guardrail_name: "pii-detection" + litellm_params: + guardrail: bedrock + mode: pre_call + guardrailIdentifier: "pii-id" + guardrailVersion: "1" + + - guardrail_name: "content-moderation" + litellm_params: + guardrail: bedrock + mode: post_call + guardrailIdentifier: "content-id" + guardrailVersion: "1" + + - guardrail_name: "prompt-injection" + litellm_params: + guardrail: lakera + mode: pre_call + api_key: os.environ/LAKERA_API_KEY + +general_settings: + pass_through_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + guardrails: + pii-detection: + request_fields: ["input", "query"] + content-moderation: + prompt-injection: + request_fields: ["messages[*].content"] +``` diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 2002550a25..917fdca602 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -419,7 +419,8 @@ const sidebars = { ] }, "pass_through/vllm", - "proxy/pass_through" + "proxy/pass_through", + "proxy/pass_through_guardrails" ] }, "rag_ingest", diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 305b7d6ddc..0b0f483ff7 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -374,6 +374,9 @@ class Logging(LiteLLMLoggingBaseClass): # Init Caching related details self.caching_details: Optional[CachingDetails] = None + # Passthrough endpoint guardrails config for field targeting + self.passthrough_guardrails_config: Optional[Dict[str, Any]] = None + self.model_call_details: Dict[str, Any] = { "litellm_trace_id": litellm_trace_id, "litellm_call_id": litellm_call_id, diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py 
b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index 06a1b92e1b..6aba2947d3 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -41,6 +41,7 @@ class AnthropicMessagesHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input messages by applying guardrails to text content. @@ -145,6 +146,7 @@ class AnthropicMessagesHandler(BaseTranslation): self, response: "AnthropicMessagesResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response by applying guardrails to text content. diff --git a/litellm/llms/base_llm/guardrail_translation/base_translation.py b/litellm/llms/base_llm/guardrail_translation/base_translation.py index 4599af1b74..926ad59cee 100644 --- a/litellm/llms/base_llm/guardrail_translation/base_translation.py +++ b/litellm/llms/base_llm/guardrail_translation/base_translation.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from litellm.integrations.custom_guardrail import CustomGuardrail + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj class BaseTranslation(ABC): @@ -11,6 +12,7 @@ class BaseTranslation(ABC): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None, ) -> Any: pass @@ -19,5 +21,6 @@ class BaseTranslation(ABC): self, response: Any, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None, ) -> Any: pass diff --git a/litellm/llms/cohere/rerank/guardrail_translation/handler.py b/litellm/llms/cohere/rerank/guardrail_translation/handler.py index 0c5e50dc41..a5a5ef68b8 100644 --- a/litellm/llms/cohere/rerank/guardrail_translation/handler.py +++ b/litellm/llms/cohere/rerank/guardrail_translation/handler.py @@ -5,7 +5,7 @@ This module provides guardrail translation support for the rerank endpoint. The handler processes only the 'query' parameter for guardrails. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import verbose_proxy_logger from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation @@ -34,6 +34,7 @@ class CohereRerankHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input query by applying guardrails. @@ -68,6 +69,7 @@ class CohereRerankHandler(BaseTranslation): self, response: "RerankResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response - not applicable for rerank. diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index b01f9f1b98..2a421a8283 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -42,6 +42,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input messages by applying guardrails to text content. 
@@ -148,6 +149,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation): self, response: "ModelResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response by applying guardrails to text content. diff --git a/litellm/llms/openai/completion/guardrail_translation/handler.py b/litellm/llms/openai/completion/guardrail_translation/handler.py index b5db730620..5a38d04d75 100644 --- a/litellm/llms/openai/completion/guardrail_translation/handler.py +++ b/litellm/llms/openai/completion/guardrail_translation/handler.py @@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text completion The handler processes the 'prompt' parameter for guardrails. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import verbose_proxy_logger from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation @@ -32,6 +32,7 @@ class OpenAITextCompletionHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input prompt by applying guardrails to text content. @@ -100,6 +101,7 @@ class OpenAITextCompletionHandler(BaseTranslation): self, response: "TextCompletionResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response by applying guardrails to completion text. diff --git a/litellm/llms/openai/image_generation/guardrail_translation/handler.py b/litellm/llms/openai/image_generation/guardrail_translation/handler.py index 5fcb5278f0..de6bca8e57 100644 --- a/litellm/llms/openai/image_generation/guardrail_translation/handler.py +++ b/litellm/llms/openai/image_generation/guardrail_translation/handler.py @@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's image generation The handler processes the 'prompt' parameter for guardrails. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import verbose_proxy_logger from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation @@ -31,6 +31,7 @@ class OpenAIImageGenerationHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input prompt by applying guardrails to text content. @@ -72,6 +73,7 @@ class OpenAIImageGenerationHandler(BaseTranslation): self, response: "ImageResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response - typically not needed for image generation. diff --git a/litellm/llms/openai/responses/guardrail_translation/handler.py b/litellm/llms/openai/responses/guardrail_translation/handler.py index fdac13176b..489a89c60c 100644 --- a/litellm/llms/openai/responses/guardrail_translation/handler.py +++ b/litellm/llms/openai/responses/guardrail_translation/handler.py @@ -56,6 +56,7 @@ class OpenAIResponsesHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input by applying guardrails to text content. 
@@ -177,6 +178,7 @@ class OpenAIResponsesHandler(BaseTranslation): self, response: "ResponsesAPIResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output response by applying guardrails to text content. diff --git a/litellm/llms/openai/speech/guardrail_translation/handler.py b/litellm/llms/openai/speech/guardrail_translation/handler.py index aa049801d1..47df79833b 100644 --- a/litellm/llms/openai/speech/guardrail_translation/handler.py +++ b/litellm/llms/openai/speech/guardrail_translation/handler.py @@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text-to-speech e The handler processes the 'input' text parameter (output is audio, so no text to guardrail). """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import verbose_proxy_logger from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation @@ -30,6 +30,7 @@ class OpenAITextToSpeechHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input text by applying guardrails. @@ -72,6 +73,7 @@ class OpenAITextToSpeechHandler(BaseTranslation): self, response: "HttpxBinaryResponseContent", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output - not applicable for text-to-speech. diff --git a/litellm/llms/openai/transcriptions/guardrail_translation/handler.py b/litellm/llms/openai/transcriptions/guardrail_translation/handler.py index 22b93251be..51f50c9180 100644 --- a/litellm/llms/openai/transcriptions/guardrail_translation/handler.py +++ b/litellm/llms/openai/transcriptions/guardrail_translation/handler.py @@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's audio transcript The handler processes the output transcribed text (input is audio, so no text to guardrail). """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional from litellm._logging import verbose_proxy_logger from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation @@ -30,6 +30,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation): self, data: dict, guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process input - not applicable for audio transcription. @@ -54,6 +55,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation): self, response: "TranscriptionResponse", guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional[Any] = None, ) -> Any: """ Process output transcription by applying guardrails to transcribed text. diff --git a/litellm/llms/pass_through/__init__.py b/litellm/llms/pass_through/__init__.py new file mode 100644 index 0000000000..803772f14d --- /dev/null +++ b/litellm/llms/pass_through/__init__.py @@ -0,0 +1,12 @@ +""" +Pass-Through Endpoint Guardrail Translation + +This module exists here (under litellm/llms/) so it can be auto-discovered by +load_guardrail_translation_mappings() which scans for guardrail_translation +directories under litellm/llms/. + +The main passthrough endpoint implementation is in: + litellm/proxy/pass_through_endpoints/ + +See guardrail_translation/README.md for more details. 
+""" diff --git a/litellm/llms/pass_through/guardrail_translation/README.md b/litellm/llms/pass_through/guardrail_translation/README.md new file mode 100644 index 0000000000..db4c0704e7 --- /dev/null +++ b/litellm/llms/pass_through/guardrail_translation/README.md @@ -0,0 +1,41 @@ +# Pass-Through Endpoint Guardrail Translation + +## Why This Exists Here + +This module is located under `litellm/llms/` (instead of with the main passthrough code) because: + +1. **Auto-discovery**: The `load_guardrail_translation_mappings()` function in `litellm/llms/__init__.py` scans for `guardrail_translation/` directories under `litellm/llms/` +2. **Consistency**: All other guardrail translation handlers follow this pattern (e.g., `openai/chat/guardrail_translation/`, `anthropic/chat/guardrail_translation/`) + +## Main Passthrough Implementation + +The main passthrough endpoint implementation is in: + +``` +litellm/proxy/pass_through_endpoints/ +├── pass_through_endpoints.py # Core passthrough routing logic +├── passthrough_guardrails.py # Guardrail collection and field targeting +├── jsonpath_extractor.py # JSONPath field extraction utility +└── ... +``` + +## What This Handler Does + +The `PassThroughEndpointHandler` enables guardrails to run on passthrough endpoint requests by: + +1. **Field Targeting**: Extracts specific fields from the request/response using JSONPath expressions configured in `request_fields` / `response_fields` +2. **Full Payload Fallback**: If no field targeting is configured, processes the entire payload +3. **Config Access**: Uses `get_passthrough_guardrails_config()` / `set_passthrough_guardrails_config()` helpers to access the passthrough guardrails configuration stored in request metadata + +## Example Config + +```yaml +passthrough_endpoints: + - path: "/v1/rerank" + target: "https://api.cohere.com/v1/rerank" + guardrails: + bedrock-pre-guard: + request_fields: ["query", "documents[*].text"] + response_fields: ["results[*].text"] +``` + diff --git a/litellm/llms/pass_through/guardrail_translation/__init__.py b/litellm/llms/pass_through/guardrail_translation/__init__.py new file mode 100644 index 0000000000..db69c8e378 --- /dev/null +++ b/litellm/llms/pass_through/guardrail_translation/__init__.py @@ -0,0 +1,15 @@ +"""Pass-Through Endpoint guardrail translation handler.""" + +from litellm.llms.pass_through.guardrail_translation.handler import ( + PassThroughEndpointHandler, +) +from litellm.types.utils import CallTypes + +guardrail_translation_mappings = { + CallTypes.pass_through: PassThroughEndpointHandler, +} + +__all__ = [ + "guardrail_translation_mappings", + "PassThroughEndpointHandler", +] diff --git a/litellm/llms/pass_through/guardrail_translation/handler.py b/litellm/llms/pass_through/guardrail_translation/handler.py new file mode 100644 index 0000000000..5ff9fd25c5 --- /dev/null +++ b/litellm/llms/pass_through/guardrail_translation/handler.py @@ -0,0 +1,165 @@ +""" +Pass-Through Endpoint Message Handler for Unified Guardrails + +This module provides a handler for passthrough endpoint requests. +It uses the field targeting configuration from litellm_logging_obj +to extract specific fields for guardrail processing. 
+""" + +from typing import TYPE_CHECKING, Any, List, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation +from litellm.proxy._types import PassThroughGuardrailSettings + +if TYPE_CHECKING: + from litellm.integrations.custom_guardrail import CustomGuardrail + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + +class PassThroughEndpointHandler(BaseTranslation): + """ + Handler for processing passthrough endpoint requests with guardrails. + + Uses passthrough_guardrails_config from litellm_logging_obj + to determine which fields to extract for guardrail processing. + """ + + def _get_guardrail_settings( + self, + litellm_logging_obj: Optional["LiteLLMLoggingObj"], + guardrail_name: Optional[str], + ) -> Optional[PassThroughGuardrailSettings]: + """ + Get the guardrail settings for a specific guardrail from logging_obj. + """ + from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( + PassthroughGuardrailHandler, + ) + + if litellm_logging_obj is None: + return None + + passthrough_config = getattr( + litellm_logging_obj, "passthrough_guardrails_config", None + ) + if not passthrough_config or not guardrail_name: + return None + + return PassthroughGuardrailHandler.get_settings( + passthrough_config, guardrail_name + ) + + def _extract_text_for_guardrail( + self, + data: dict, + field_expressions: Optional[List[str]], + ) -> str: + """ + Extract text from data for guardrail processing. + + If field_expressions provided, extracts only those fields. + Otherwise, returns the full payload as JSON. + """ + from litellm.proxy.pass_through_endpoints.jsonpath_extractor import ( + JsonPathExtractor, + ) + + if field_expressions: + text = JsonPathExtractor.extract_fields( + data=data, + jsonpath_expressions=field_expressions, + ) + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: Extracted targeted fields: %s", + text[:200] if text else None, + ) + return text + + # Use entire payload, excluding internal fields + from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + + payload_to_check = { + k: v + for k, v in data.items() + if not k.startswith("_") and k not in ("metadata", "litellm_logging_obj") + } + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: Using full payload for guardrail" + ) + return safe_dumps(payload_to_check) + + async def process_input_messages( + self, + data: dict, + guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None, + ) -> Any: + """ + Process input by applying guardrails to targeted fields or full payload. 
+ """ + guardrail_name = guardrail_to_apply.guardrail_name + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: Processing input for guardrail=%s", + guardrail_name, + ) + + # Get field targeting settings for this guardrail + settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name) + field_expressions = settings.request_fields if settings else None + + # Extract text to check + text_to_check = self._extract_text_for_guardrail(data, field_expressions) + + if not text_to_check: + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: No text to check, skipping guardrail" + ) + return data + + # Apply guardrail + await guardrail_to_apply.apply_guardrail( + text=text_to_check, + request_data=data, + ) + + return data + + async def process_output_response( + self, + response: Any, + guardrail_to_apply: "CustomGuardrail", + litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None, + ) -> Any: + """ + Process output response by applying guardrails to targeted fields. + """ + if not isinstance(response, dict): + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: Response is not a dict, skipping" + ) + return response + + guardrail_name = guardrail_to_apply.guardrail_name + verbose_proxy_logger.debug( + "PassThroughEndpointHandler: Processing output for guardrail=%s", + guardrail_name, + ) + + # Get field targeting settings for this guardrail + settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name) + field_expressions = settings.response_fields if settings else None + + # Extract text to check + text_to_check = self._extract_text_for_guardrail(response, field_expressions) + + if not text_to_check: + return response + + # Apply guardrail + await guardrail_to_apply.apply_guardrail( + text=text_to_check, + request_data=response, + ) + + return response diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 7e30079e78..9e915d4bc5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1693,27 +1693,26 @@ class DynamoDBArgs(LiteLLMPydanticObjectBase): assume_role_aws_session_name: Optional[str] = None -class PassThroughGuardrailConfig(LiteLLMPydanticObjectBase): +class PassThroughGuardrailSettings(LiteLLMPydanticObjectBase): """ - Configuration for guardrails on passthrough endpoints. + Settings for a specific guardrail on a passthrough endpoint. - Passthrough endpoints are opt-in only for guardrails. Guardrails configured at - org/team/key levels will NOT execute unless explicitly enabled here. + Allows field-level targeting for guardrail execution. """ - enabled: bool = Field( - default=False, - description="Whether to execute guardrails for this passthrough endpoint. When True, all org/team/key level guardrails will execute along with any passthrough-specific guardrails. When False (default), NO guardrails execute.", - ) - specific: Optional[List[str]] = Field( + request_fields: Optional[List[str]] = Field( default=None, - description="Optional list of guardrail names that are specific to this passthrough endpoint. These will execute in addition to org/team/key level guardrails when enabled=True.", + description="JSONPath expressions for input field targeting (pre_call). Examples: 'query', 'documents[*].text', 'messages[*].content'. If not specified, guardrail runs on entire request payload.", ) - target_fields: Optional[List[str]] = Field( + response_fields: Optional[List[str]] = Field( default=None, - description="Optional list of JSON paths to target specific fields for guardrail execution. 
Examples: 'messages[*].content', 'input', 'messages[?(@.role=='user')].content'. If not specified, guardrails execute on entire payload.", + description="JSONPath expressions for output field targeting (post_call). Examples: 'results[*].text', 'output'. If not specified, guardrail runs on entire response payload.", ) +# Type alias for the guardrails dict: guardrail_name -> settings (or None for defaults) +PassThroughGuardrailsConfig = Dict[str, Optional[PassThroughGuardrailSettings]] + + class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase): id: Optional[str] = Field( default=None, @@ -1739,9 +1738,9 @@ class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase): default=False, description="Whether authentication is required for the pass-through endpoint. If True, requests to the endpoint will require a valid LiteLLM API key.", ) - guardrails: Optional[PassThroughGuardrailConfig] = Field( + guardrails: Optional[PassThroughGuardrailsConfig] = Field( default=None, - description="Guardrail configuration for this passthrough endpoint. When enabled, org/team/key level guardrails will execute along with any passthrough-specific guardrails. Defaults to disabled (no guardrails execute).", + description="Guardrails configuration for this passthrough endpoint. Dict keys are guardrail names, values are optional settings for field targeting. When set, all org/team/key level guardrails will also execute. Defaults to None (no guardrails execute).", ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py index 824ed4e0b0..9ee1eb8671 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py @@ -86,6 +86,7 @@ class UnifiedLLMGuardrails(CustomLogger): data = await endpoint_translation.process_input_messages( data=data, guardrail_to_apply=guardrail_to_apply, + litellm_logging_obj=data.get("litellm_logging_obj"), ) # Add guardrail to applied guardrails header @@ -148,6 +149,7 @@ class UnifiedLLMGuardrails(CustomLogger): response = await endpoint_translation.process_output_response( response=response, # type: ignore guardrail_to_apply=guardrail_to_apply, + litellm_logging_obj=data.get("litellm_logging_obj"), ) # Add guardrail to applied guardrails header add_guardrail_to_applied_guardrails_header( diff --git a/litellm/proxy/pass_through_endpoints/jsonpath_extractor.py b/litellm/proxy/pass_through_endpoints/jsonpath_extractor.py new file mode 100644 index 0000000000..fde2553be4 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/jsonpath_extractor.py @@ -0,0 +1,95 @@ +""" +JSONPath Extractor Module + +Extracts field values from data using simple JSONPath-like expressions. +""" + +from typing import Any, List, Union + +from litellm._logging import verbose_proxy_logger + + +class JsonPathExtractor: + """Extracts field values from data using JSONPath-like expressions.""" + + @staticmethod + def extract_fields( + data: dict, + jsonpath_expressions: List[str], + ) -> str: + """ + Extract field values from data using JSONPath-like expressions. + + Supports simple expressions like: + - "query" -> data["query"] + - "documents[*].text" -> all text fields from documents array + - "messages[*].content" -> all content fields from messages array + + Returns concatenated string of all extracted values. 
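+
+        Example (illustrative; extracted values are newline-joined):
+            extract_fields(
+                data={"query": "hi", "documents": [{"text": "a"}, {"text": "b"}]},
+                jsonpath_expressions=["query", "documents[*].text"],
+            )
+            -> "hi\na\nb"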
+ """ + extracted_values: List[str] = [] + + for expr in jsonpath_expressions: + try: + value = JsonPathExtractor.evaluate(data, expr) + if value: + if isinstance(value, list): + extracted_values.extend([str(v) for v in value if v]) + else: + extracted_values.append(str(value)) + except Exception as e: + verbose_proxy_logger.debug( + "Failed to extract field %s: %s", expr, str(e) + ) + + return "\n".join(extracted_values) + + @staticmethod + def evaluate(data: dict, expr: str) -> Union[str, List[str], None]: + """ + Evaluate a simple JSONPath-like expression. + + Supports: + - Simple key: "query" -> data["query"] + - Nested key: "foo.bar" -> data["foo"]["bar"] + - Array wildcard: "items[*].text" -> [item["text"] for item in data["items"]] + """ + if not expr or not data: + return None + + parts = expr.replace("[*]", ".[*]").split(".") + current: Any = data + + for i, part in enumerate(parts): + if current is None: + return None + + if part == "[*]": + # Wildcard - current should be a list + if not isinstance(current, list): + return None + + # Get remaining path + remaining_path = ".".join(parts[i + 1:]) + if not remaining_path: + return current + + # Recursively evaluate remaining path for each item + results = [] + for item in current: + if isinstance(item, dict): + result = JsonPathExtractor.evaluate(item, remaining_path) + if result: + if isinstance(result, list): + results.extend(result) + else: + results.append(result) + return results if results else None + + elif isinstance(current, dict): + current = current.get(part) + else: + return None + + return current + diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 644b5ce929..df4452f726 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -599,6 +599,7 @@ async def pass_through_request( # noqa: PLR0915 stream: Optional[bool] = None, cost_per_request: Optional[float] = None, custom_llm_provider: Optional[str] = None, + guardrails_config: Optional[dict] = None, ): """ Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called @@ -614,8 +615,13 @@ async def pass_through_request( # noqa: PLR0915 query_params: The query params stream: Whether to stream the response cost_per_request: Optional field - cost per request to the target endpoint + custom_llm_provider: Optional field - custom LLM provider for the endpoint + guardrails_config: Optional field - guardrails configuration for passthrough endpoint """ from litellm.litellm_core_utils.litellm_logging import Logging + from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( + PassthroughGuardrailHandler, + ) from litellm.proxy.proxy_server import proxy_logging_obj ######################################################### @@ -664,6 +670,45 @@ async def pass_through_request( # noqa: PLR0915 ) ) + ### COLLECT GUARDRAILS FOR PASSTHROUGH ENDPOINT ### + # Passthrough endpoints are opt-in only for guardrails + # When enabled, collect guardrails from org/team/key levels + passthrough-specific + guardrails_to_run = PassthroughGuardrailHandler.collect_guardrails( + user_api_key_dict=user_api_key_dict, + passthrough_guardrails_config=guardrails_config, + ) + + # Add guardrails to metadata if any should run + if guardrails_to_run and len(guardrails_to_run) > 0: + if _parsed_body is None: + _parsed_body = {} + if "metadata" not in 
_parsed_body: + _parsed_body["metadata"] = {} + _parsed_body["metadata"]["guardrails"] = guardrails_to_run + verbose_proxy_logger.debug( + f"Added guardrails to passthrough request metadata: {guardrails_to_run}" + ) + + ## LOGGING OBJECT ## - initialize before pre_call_hook so guardrails can access it + start_time = datetime.now() + logging_obj = Logging( + model="unknown", + messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], + stream=False, + call_type="pass_through_endpoint", + start_time=start_time, + litellm_call_id=litellm_call_id, + function_id="1245", + ) + + # Store passthrough guardrails config on logging_obj for field targeting + logging_obj.passthrough_guardrails_config = guardrails_config + + # Store logging_obj in data so guardrails can access it + if _parsed_body is None: + _parsed_body = {} + _parsed_body["litellm_logging_obj"] = logging_obj + ### CALL HOOKS ### - modify incoming data / reject request before calling the model _parsed_body = await proxy_logging_obj.pre_call_hook( user_api_key_dict=user_api_key_dict, @@ -675,18 +720,6 @@ async def pass_through_request( # noqa: PLR0915 params={"timeout": 600}, ) async_client = async_client_obj.client - - # create logging object - start_time = datetime.now() - logging_obj = Logging( - model="unknown", - messages=[{"role": "user", "content": safe_dumps(_parsed_body)}], - stream=False, - call_type="pass_through_endpoint", - start_time=start_time, - litellm_call_id=litellm_call_id, - function_id="1245", - ) passthrough_logging_payload = PassthroughStandardLoggingPayload( url=str(url), request_body=_parsed_body, @@ -1011,6 +1044,7 @@ def create_pass_through_route( custom_llm_provider: Optional[str] = None, is_streaming_request: Optional[bool] = False, query_params: Optional[dict] = None, + guardrails: Optional[Dict[str, Any]] = None, ): # check if target is an adapter.py or a url from litellm._uuid import uuid @@ -1079,6 +1113,7 @@ def create_pass_through_route( "forward_headers": _forward_headers, "merge_query_params": _merge_query_params, "cost_per_request": cost_per_request, + "guardrails": None, } if passthrough_params is not None: @@ -1096,6 +1131,7 @@ def create_pass_through_route( param_cost_per_request = target_params.get( "cost_per_request", cost_per_request ) + param_guardrails = target_params.get("guardrails", None) # Construct the full target URL with subpath if needed full_target = ( @@ -1135,6 +1171,7 @@ def create_pass_through_route( custom_body=final_custom_body, cost_per_request=cast(Optional[float], param_cost_per_request), custom_llm_provider=custom_llm_provider, + guardrails_config=param_guardrails, ) return endpoint_func @@ -1769,6 +1806,7 @@ class InitPassThroughEndpointHelpers: dependencies: Optional[List], cost_per_request: Optional[float], endpoint_id: str, + guardrails: Optional[dict] = None, ): """Add exact path route for pass-through endpoint""" route_key = f"{endpoint_id}:exact:{path}" @@ -1799,6 +1837,7 @@ class InitPassThroughEndpointHelpers: merge_query_params, dependencies, cost_per_request=cost_per_request, + guardrails=guardrails, ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=dependencies, @@ -1817,6 +1856,7 @@ class InitPassThroughEndpointHelpers: "merge_query_params": merge_query_params, "dependencies": dependencies, "cost_per_request": cost_per_request, + "guardrails": guardrails, }, } @@ -1831,6 +1871,7 @@ class InitPassThroughEndpointHelpers: dependencies: Optional[List], cost_per_request: Optional[float], endpoint_id: str, + guardrails: Optional[dict] = 
None, ): """Add wildcard route for sub-paths""" wildcard_path = f"{path}/{{subpath:path}}" @@ -1863,6 +1904,7 @@ class InitPassThroughEndpointHelpers: dependencies, include_subpath=True, cost_per_request=cost_per_request, + guardrails=guardrails, ), methods=["GET", "POST", "PUT", "DELETE", "PATCH"], dependencies=dependencies, @@ -1881,6 +1923,7 @@ class InitPassThroughEndpointHelpers: "merge_query_params": merge_query_params, "dependencies": dependencies, "cost_per_request": cost_per_request, + "guardrails": guardrails, }, } @@ -2057,6 +2100,9 @@ async def initialize_pass_through_endpoints( if _target is None: continue + # Get guardrails config if present + _guardrails = endpoint.get("guardrails", None) + # Add exact path route verbose_proxy_logger.debug( "Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id @@ -2071,6 +2117,7 @@ async def initialize_pass_through_endpoints( dependencies=_dependencies, cost_per_request=endpoint.get("cost_per_request", None), endpoint_id=endpoint_id, + guardrails=_guardrails, ) visited_endpoints.add(f"{endpoint_id}:exact:{_path}") @@ -2087,6 +2134,7 @@ async def initialize_pass_through_endpoints( dependencies=_dependencies, cost_per_request=endpoint.get("cost_per_request", None), endpoint_id=endpoint_id, + guardrails=_guardrails, ) visited_endpoints.add(f"{endpoint_id}:subpath:{_path}") diff --git a/litellm/proxy/pass_through_endpoints/passthrough_guardrails.py b/litellm/proxy/pass_through_endpoints/passthrough_guardrails.py new file mode 100644 index 0000000000..cec20f3e04 --- /dev/null +++ b/litellm/proxy/pass_through_endpoints/passthrough_guardrails.py @@ -0,0 +1,333 @@ +""" +Passthrough Guardrails Helper Module + +Handles guardrail execution for passthrough endpoints with: +- Opt-in model (guardrails only run when explicitly configured) +- Field-level targeting using JSONPath expressions +- Automatic inheritance from org/team/key levels when enabled +""" + +from typing import Any, Dict, List, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import ( + PassThroughGuardrailsConfig, + PassThroughGuardrailSettings, + UserAPIKeyAuth, +) +from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor + +# Type for raw guardrails config input (before normalization) +# Can be a list of names or a dict with settings +PassThroughGuardrailsConfigInput = Union[ + List[str], # Simple list: ["guard-1", "guard-2"] + PassThroughGuardrailsConfig, # Dict: {"guard-1": {"request_fields": [...]}} +] + + +class PassthroughGuardrailHandler: + """ + Handles guardrail execution for passthrough endpoints. + + Passthrough endpoints use an opt-in model for guardrails: + - Guardrails only run when explicitly configured on the endpoint + - Supports field-level targeting using JSONPath expressions + - Automatically inherits org/team/key level guardrails when enabled + + Guardrails can be specified as: + - List format (simple): ["guardrail-1", "guardrail-2"] + - Dict format (with settings): {"guardrail-1": {"request_fields": ["query"]}} + """ + + @staticmethod + def normalize_config( + guardrails_config: Optional[PassThroughGuardrailsConfigInput], + ) -> Optional[PassThroughGuardrailsConfig]: + """ + Normalize guardrails config to dict format. 
+ + Accepts: + - List of guardrail names: ["g1", "g2"] -> {"g1": None, "g2": None} + - Dict with settings: {"g1": {"request_fields": [...]}} + - None: returns None + """ + if guardrails_config is None: + return None + + # Already a dict - return as-is + if isinstance(guardrails_config, dict): + return guardrails_config + + # List of guardrail names - convert to dict + if isinstance(guardrails_config, list): + return {name: None for name in guardrails_config} + + verbose_proxy_logger.debug( + "Passthrough guardrails config is not a dict or list, got: %s", + type(guardrails_config), + ) + return None + + @staticmethod + def is_enabled( + guardrails_config: Optional[PassThroughGuardrailsConfigInput], + ) -> bool: + """ + Check if guardrails are enabled for a passthrough endpoint. + + Passthrough endpoints are opt-in only - guardrails only run when + the guardrails config is set with at least one guardrail. + """ + normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config) + if normalized is None: + return False + return len(normalized) > 0 + + @staticmethod + def get_guardrail_names( + guardrails_config: Optional[PassThroughGuardrailsConfigInput], + ) -> List[str]: + """Get the list of guardrail names configured for a passthrough endpoint.""" + normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config) + if normalized is None: + return [] + return list(normalized.keys()) + + @staticmethod + def get_settings( + guardrails_config: Optional[PassThroughGuardrailsConfigInput], + guardrail_name: str, + ) -> Optional[PassThroughGuardrailSettings]: + """Get settings for a specific guardrail from the passthrough config.""" + normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config) + if normalized is None: + return None + + settings = normalized.get(guardrail_name) + if settings is None: + return None + + if isinstance(settings, dict): + return PassThroughGuardrailSettings(**settings) + + return settings + + @staticmethod + def prepare_input( + request_data: dict, + guardrail_settings: Optional[PassThroughGuardrailSettings], + ) -> str: + """ + Prepare input text for guardrail execution based on field targeting settings. + + If request_fields is specified, extracts only those fields. + Otherwise, uses the entire request payload as text. + """ + if guardrail_settings is None or guardrail_settings.request_fields is None: + from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + return safe_dumps(request_data) + + return JsonPathExtractor.extract_fields( + data=request_data, + jsonpath_expressions=guardrail_settings.request_fields, + ) + + @staticmethod + def prepare_output( + response_data: dict, + guardrail_settings: Optional[PassThroughGuardrailSettings], + ) -> str: + """ + Prepare output text for guardrail execution based on field targeting settings. + + If response_fields is specified, extracts only those fields. + Otherwise, uses the entire response payload as text. + """ + if guardrail_settings is None or guardrail_settings.response_fields is None: + from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + return safe_dumps(response_data) + + return JsonPathExtractor.extract_fields( + data=response_data, + jsonpath_expressions=guardrail_settings.response_fields, + ) + + @staticmethod + async def execute( + request_data: dict, + user_api_key_dict: UserAPIKeyAuth, + guardrails_config: Optional[PassThroughGuardrailsConfig], + event_type: str = "pre_call", + ) -> dict: + """ + Execute guardrails for a passthrough endpoint. 
+ + This is the main entry point for passthrough guardrail execution. + + Args: + request_data: The request payload + user_api_key_dict: User API key authentication info + guardrails_config: Passthrough-specific guardrails configuration + event_type: "pre_call" for request, "post_call" for response + + Returns: + The potentially modified request_data + + Raises: + HTTPException if a guardrail blocks the request + """ + if not PassthroughGuardrailHandler.is_enabled(guardrails_config): + verbose_proxy_logger.debug( + "Passthrough guardrails not enabled, skipping guardrail execution" + ) + return request_data + + guardrail_names = PassthroughGuardrailHandler.get_guardrail_names( + guardrails_config + ) + verbose_proxy_logger.debug( + "Executing passthrough guardrails: %s", guardrail_names + ) + + # Add to request metadata so guardrails know which to run + from litellm.proxy.pass_through_endpoints.passthrough_context import ( + set_passthrough_guardrails_config, + ) + + if "metadata" not in request_data: + request_data["metadata"] = {} + + # Set guardrails in metadata using dict format for compatibility + request_data["metadata"]["guardrails"] = { + name: True for name in guardrail_names + } + + # Store passthrough guardrails config in request-scoped context + set_passthrough_guardrails_config(guardrails_config) + + return request_data + + @staticmethod + def collect_guardrails( + user_api_key_dict: UserAPIKeyAuth, + passthrough_guardrails_config: Optional[PassThroughGuardrailsConfigInput], + ) -> Optional[Dict[str, bool]]: + """ + Collect guardrails for a passthrough endpoint. + + Passthrough endpoints are opt-in only for guardrails. Guardrails only run when + the guardrails config is set with at least one guardrail. + + Accepts both list and dict formats: + - List: ["guardrail-1", "guardrail-2"] + - Dict: {"guardrail-1": {"request_fields": [...]}} + + When enabled, this function collects: + - Passthrough-specific guardrails from the config + - Org/team/key level guardrails (automatic inheritance when passthrough is enabled) + + Args: + user_api_key_dict: User API key authentication info + passthrough_guardrails_config: List or Dict of guardrail names/settings + + Returns: + Dict of guardrail names to run (format: {guardrail_name: True}), or None + """ + from litellm.proxy.litellm_pre_call_utils import ( + _add_guardrails_from_key_or_team_metadata, + ) + + # Normalize config to dict format (handles both list and dict) + normalized_config = PassthroughGuardrailHandler.normalize_config( + passthrough_guardrails_config + ) + + if normalized_config is None: + verbose_proxy_logger.debug( + "Passthrough guardrails not configured, skipping guardrail collection" + ) + return None + + if len(normalized_config) == 0: + verbose_proxy_logger.debug( + "Passthrough guardrails config is empty, skipping" + ) + return None + + # Passthrough is enabled - collect guardrails + guardrails_to_run: Dict[str, bool] = {} + + # Add passthrough-specific guardrails + for guardrail_name in normalized_config.keys(): + guardrails_to_run[guardrail_name] = True + verbose_proxy_logger.debug( + "Added passthrough-specific guardrail" + ) + + # Add org/team/key level guardrails using shared helper + temp_data: Dict[str, Any] = {"metadata": {}} + _add_guardrails_from_key_or_team_metadata( + key_metadata=user_api_key_dict.metadata, + team_metadata=user_api_key_dict.team_metadata, + data=temp_data, + metadata_variable_name="metadata", + ) + + # Merge inherited guardrails into guardrails_to_run + inherited_guardrails = 
temp_data["metadata"].get("guardrails", []) + for guardrail_name in inherited_guardrails: + if guardrail_name not in guardrails_to_run: + guardrails_to_run[guardrail_name] = True + verbose_proxy_logger.debug( + "Added inherited guardrail (key/team level)" + ) + + verbose_proxy_logger.debug( + "Collected total guardrails for passthrough endpoint: %d", + len(guardrails_to_run), + ) + + return guardrails_to_run if guardrails_to_run else None + + @staticmethod + def get_field_targeted_text( + data: dict, + guardrail_name: str, + is_request: bool = True, + ) -> Optional[str]: + """ + Get the text to check for a guardrail, respecting field targeting settings. + + Called by guardrail hooks to get the appropriate text based on + passthrough field targeting configuration. + + Args: + data: The request/response data dict + guardrail_name: Name of the guardrail being executed + is_request: True for request (pre_call), False for response (post_call) + + Returns: + The text to check, or None to use default behavior + """ + from litellm.proxy.pass_through_endpoints.passthrough_context import ( + get_passthrough_guardrails_config, + ) + + passthrough_config = get_passthrough_guardrails_config() + if passthrough_config is None: + return None + + settings = PassthroughGuardrailHandler.get_settings( + passthrough_config, guardrail_name + ) + if settings is None: + return None + + if is_request: + if settings.request_fields: + return JsonPathExtractor.extract_fields(data, settings.request_fields) + else: + if settings.response_fields: + return JsonPathExtractor.extract_fields(data, settings.response_fields) + + return None diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 26e867dc33..098cdb80e0 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -3,7 +3,14 @@ model_list: litellm_params: model: bedrock/openai/arn:aws:bedrock:us-east-1:046319184608:imported-model/0m2lasirsp6z - +guardrails: + - guardrail_name: "bedrock-pre-guard" + litellm_params: + guardrail: bedrock + mode: "pre_call" + guardrailIdentifier: ff6ujrregl1q + guardrailVersion: "DRAFT" + # like MCPs/vector stores search_tools: @@ -40,6 +47,14 @@ litellm_settings: general_settings: store_prompts_in_spend_logs: True + pass_through_endpoints: + - path: "/special/rerank" + target: "https://api.cohere.com/v1/rerank" + headers: + Authorization: "Bearer os.environ/COHERE_API_KEY" + guardrails: + bedrock-pre-guard: + request_fields: ["documents[*].text"] vector_store_registry: diff --git a/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails_field_targeting.py b/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails_field_targeting.py new file mode 100644 index 0000000000..05f367ff13 --- /dev/null +++ b/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails_field_targeting.py @@ -0,0 +1,119 @@ +""" +Test passthrough guardrails field-level targeting. + +Tests that request_fields and response_fields correctly extract +and send only specified fields to the guardrail. 
+""" + +import json +import os +import sys +from unittest.mock import AsyncMock, MagicMock + +import pytest + +sys.path.insert(0, os.path.abspath("../..")) + +from litellm.proxy._types import PassThroughGuardrailSettings +from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( + PassthroughGuardrailHandler, +) + + +def test_no_fields_set_sends_full_body(): + """ + Test that when no request_fields are set, the entire request body + is JSON dumped and sent to the guardrail. + """ + request_data = { + "model": "rerank-english-v3.0", + "query": "What is coffee?", + "documents": [ + {"text": "Paris is the capital of France."}, + {"text": "Coffee is a brewed drink."} + ] + } + + # No guardrail settings means full body + result = PassthroughGuardrailHandler.prepare_input( + request_data=request_data, + guardrail_settings=None + ) + + # Result should be JSON string of full request + assert isinstance(result, str) + result_dict = json.loads(result) + + # Should contain all fields + assert "model" in result_dict + assert "query" in result_dict + assert "documents" in result_dict + assert result_dict["query"] == "What is coffee?" + assert len(result_dict["documents"]) == 2 + + +def test_request_fields_query_only(): + """ + Test that when request_fields is set to ["query"], only the query field + is extracted and sent to the guardrail. + """ + request_data = { + "model": "rerank-english-v3.0", + "query": "What is coffee?", + "documents": [ + {"text": "Paris is the capital of France."}, + {"text": "Coffee is a brewed drink."} + ] + } + + # Set request_fields to only extract query + guardrail_settings = PassThroughGuardrailSettings( + request_fields=["query"] + ) + + result = PassthroughGuardrailHandler.prepare_input( + request_data=request_data, + guardrail_settings=guardrail_settings + ) + + # Result should only contain query + assert isinstance(result, str) + assert "What is coffee?" in result + + # Should NOT contain documents + assert "Paris is the capital" not in result + assert "Coffee is a brewed drink" not in result + + +def test_request_fields_documents_wildcard(): + """ + Test that when request_fields is set to ["documents[*]"], only the documents + array is extracted and sent to the guardrail. + """ + request_data = { + "model": "rerank-english-v3.0", + "query": "What is coffee?", + "documents": [ + {"text": "Paris is the capital of France."}, + {"text": "Coffee is a brewed drink."} + ] + } + + # Set request_fields to extract documents array + guardrail_settings = PassThroughGuardrailSettings( + request_fields=["documents[*]"] + ) + + result = PassthroughGuardrailHandler.prepare_input( + request_data=request_data, + guardrail_settings=guardrail_settings + ) + + # Result should contain documents + assert isinstance(result, str) + assert "Paris is the capital" in result + assert "Coffee is a brewed drink" in result + + # Should NOT contain query + assert "What is coffee?" not in result + diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails.py b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails.py new file mode 100644 index 0000000000..3422e68957 --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_guardrails.py @@ -0,0 +1,263 @@ +""" +Unit tests for passthrough guardrails functionality. 
+ +Tests the opt-in guardrail execution model for passthrough endpoints: +- Guardrails only run when explicitly configured +- Field-level targeting with JSONPath expressions +- Automatic inheritance from org/team/key levels when enabled +""" + +import pytest + +from litellm.proxy._types import PassThroughGuardrailSettings +from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor +from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( + PassthroughGuardrailHandler, +) + + +class TestPassthroughGuardrailHandlerIsEnabled: + """Tests for PassthroughGuardrailHandler.is_enabled method.""" + + def test_returns_false_when_config_is_none(self): + """Guardrails should be disabled when config is None.""" + result = PassthroughGuardrailHandler.is_enabled(None) + assert result is False + + def test_returns_false_when_config_is_empty_dict(self): + """Guardrails should be disabled when config is empty dict.""" + result = PassthroughGuardrailHandler.is_enabled({}) + assert result is False + + def test_returns_true_when_config_is_list(self): + """Guardrails should be enabled when config is a list of names.""" + result = PassthroughGuardrailHandler.is_enabled(["pii-detection"]) + assert result is True + + def test_returns_false_when_config_is_invalid_type(self): + """Guardrails should be disabled when config is not a dict or list.""" + result = PassthroughGuardrailHandler.is_enabled("pii-detection") # type: ignore + assert result is False + + def test_returns_true_when_config_has_guardrails(self): + """Guardrails should be enabled when config has at least one guardrail.""" + config = {"pii-detection": None} + result = PassthroughGuardrailHandler.is_enabled(config) + assert result is True + + def test_returns_true_with_multiple_guardrails(self): + """Guardrails should be enabled with multiple guardrails configured.""" + config = { + "pii-detection": None, + "content-moderation": {"request_fields": ["input"]}, + } + result = PassthroughGuardrailHandler.is_enabled(config) + assert result is True + + +class TestPassthroughGuardrailHandlerGetGuardrailNames: + """Tests for PassthroughGuardrailHandler.get_guardrail_names method.""" + + def test_returns_empty_list_when_disabled(self): + """Should return empty list when guardrails are disabled.""" + result = PassthroughGuardrailHandler.get_guardrail_names(None) + assert result == [] + + def test_returns_guardrail_names(self): + """Should return list of guardrail names from config.""" + config = { + "pii-detection": None, + "content-moderation": {"request_fields": ["input"]}, + } + result = PassthroughGuardrailHandler.get_guardrail_names(config) + assert set(result) == {"pii-detection", "content-moderation"} + + +class TestPassthroughGuardrailHandlerNormalizeConfig: + """Tests for PassthroughGuardrailHandler.normalize_config method.""" + + def test_normalizes_list_to_dict(self): + """List of guardrail names should be converted to dict with None values.""" + config = ["pii-detection", "content-moderation"] + result = PassthroughGuardrailHandler.normalize_config(config) + assert result == {"pii-detection": None, "content-moderation": None} + + def test_returns_dict_unchanged(self): + """Dict config should be returned as-is.""" + config = {"pii-detection": {"request_fields": ["query"]}} + result = PassthroughGuardrailHandler.normalize_config(config) + assert result == config + + +class TestPassthroughGuardrailHandlerGetSettings: + """Tests for PassthroughGuardrailHandler.get_settings method.""" + + def 
test_returns_none_when_config_is_none(self): + """Should return None when config is None.""" + result = PassthroughGuardrailHandler.get_settings(None, "pii-detection") + assert result is None + + def test_returns_none_when_guardrail_not_in_config(self): + """Should return None when guardrail is not in config.""" + config = {"pii-detection": None} + result = PassthroughGuardrailHandler.get_settings(config, "content-moderation") + assert result is None + + def test_returns_none_when_settings_is_none(self): + """Should return None when guardrail has no settings.""" + config = {"pii-detection": None} + result = PassthroughGuardrailHandler.get_settings(config, "pii-detection") + assert result is None + + def test_returns_settings_object(self): + """Should return PassThroughGuardrailSettings when settings are provided.""" + config = { + "pii-detection": { + "request_fields": ["input", "query"], + "response_fields": ["output"], + } + } + result = PassthroughGuardrailHandler.get_settings(config, "pii-detection") + assert result is not None + assert result.request_fields == ["input", "query"] + assert result.response_fields == ["output"] + + +class TestJsonPathExtractorEvaluate: + """Tests for JsonPathExtractor.evaluate method.""" + + def test_simple_key(self): + """Should extract simple key from dict.""" + data = {"query": "test query", "other": "value"} + result = JsonPathExtractor.evaluate(data, "query") + assert result == "test query" + + def test_nested_key(self): + """Should extract nested key from dict.""" + data = {"foo": {"bar": "nested value"}} + result = JsonPathExtractor.evaluate(data, "foo.bar") + assert result == "nested value" + + def test_array_wildcard(self): + """Should extract values from array using wildcard.""" + data = {"items": [{"text": "item1"}, {"text": "item2"}, {"text": "item3"}]} + result = JsonPathExtractor.evaluate(data, "items[*].text") + assert result == ["item1", "item2", "item3"] + + def test_messages_content(self): + """Should extract content from messages array.""" + data = { + "messages": [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + ] + } + result = JsonPathExtractor.evaluate(data, "messages[*].content") + assert result == ["You are helpful", "Hello"] + + def test_missing_key_returns_none(self): + """Should return None for missing key.""" + data = {"query": "test"} + result = JsonPathExtractor.evaluate(data, "missing") + assert result is None + + def test_empty_data_returns_none(self): + """Should return None for empty data.""" + result = JsonPathExtractor.evaluate({}, "query") + assert result is None + + def test_empty_expression_returns_none(self): + """Should return None for empty expression.""" + data = {"query": "test"} + result = JsonPathExtractor.evaluate(data, "") + assert result is None + + +class TestJsonPathExtractorExtractFields: + """Tests for JsonPathExtractor.extract_fields method.""" + + def test_extracts_multiple_fields(self): + """Should extract and concatenate multiple fields.""" + data = {"query": "search query", "input": "additional input"} + result = JsonPathExtractor.extract_fields(data, ["query", "input"]) + assert "search query" in result + assert "additional input" in result + + def test_extracts_array_fields(self): + """Should extract and concatenate array fields.""" + data = {"documents": [{"text": "doc1"}, {"text": "doc2"}]} + result = JsonPathExtractor.extract_fields(data, ["documents[*].text"]) + assert "doc1" in result + assert "doc2" in result + + def 
test_handles_missing_fields(self): + """Should handle missing fields gracefully.""" + data = {"query": "test"} + result = JsonPathExtractor.extract_fields(data, ["query", "missing"]) + assert result == "test" + + def test_empty_fields_returns_empty_string(self): + """Should return empty string for empty fields list.""" + data = {"query": "test"} + result = JsonPathExtractor.extract_fields(data, []) + assert result == "" + + +class TestPassthroughGuardrailHandlerPrepareInput: + """Tests for PassthroughGuardrailHandler.prepare_input method.""" + + def test_returns_full_payload_when_no_settings(self): + """Should return full JSON payload when no settings provided.""" + data = {"query": "test", "input": "value"} + result = PassthroughGuardrailHandler.prepare_input(data, None) + assert "query" in result + assert "input" in result + + def test_returns_full_payload_when_no_request_fields(self): + """Should return full JSON payload when request_fields not set.""" + data = {"query": "test", "input": "value"} + settings = PassThroughGuardrailSettings(response_fields=["output"]) + result = PassthroughGuardrailHandler.prepare_input(data, settings) + assert "query" in result + assert "input" in result + + def test_returns_targeted_fields(self): + """Should return only targeted fields when request_fields set.""" + data = {"query": "targeted", "input": "also targeted", "other": "ignored"} + settings = PassThroughGuardrailSettings(request_fields=["query", "input"]) + result = PassthroughGuardrailHandler.prepare_input(data, settings) + assert "targeted" in result + assert "also targeted" in result + assert "ignored" not in result + + +class TestPassthroughGuardrailHandlerPrepareOutput: + """Tests for PassthroughGuardrailHandler.prepare_output method.""" + + def test_returns_full_payload_when_no_settings(self): + """Should return full JSON payload when no settings provided.""" + data = {"results": [{"text": "result1"}], "output": "value"} + result = PassthroughGuardrailHandler.prepare_output(data, None) + assert "results" in result + assert "output" in result + + def test_returns_full_payload_when_no_response_fields(self): + """Should return full JSON payload when response_fields not set.""" + data = {"results": [{"text": "result1"}], "output": "value"} + settings = PassThroughGuardrailSettings(request_fields=["input"]) + result = PassthroughGuardrailHandler.prepare_output(data, settings) + assert "results" in result + assert "output" in result + + def test_returns_targeted_fields(self): + """Should return only targeted fields when response_fields set.""" + data = { + "results": [{"text": "targeted1"}, {"text": "targeted2"}], + "other": "ignored", + } + settings = PassThroughGuardrailSettings(response_fields=["results[*].text"]) + result = PassthroughGuardrailHandler.prepare_output(data, settings) + assert "targeted1" in result + assert "targeted2" in result + assert "ignored" not in result +
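
For reference, the tests above pin down the JSONPath-style extraction semantics used for field targeting: simple keys, dotted paths, `[*]` array wildcards, and concatenation of all matches into the single string handed to the guardrail. The sketch below is a minimal, self-contained illustration of that behaviour under those assumptions; the helper names (`evaluate_path`, `_walk`, `extract_fields`) and the newline join character are illustrative only and are not the `JsonPathExtractor` implementation added by this PR.

```python
# Hedged sketch only -- mimics the behaviour asserted by the tests above,
# not the actual JsonPathExtractor added in this PR.
from typing import Any, List, Optional


def evaluate_path(data: dict, expression: str) -> Optional[Any]:
    """Resolve a dotted path with optional [*] wildcards against a dict."""
    if not data or not expression:
        return None
    return _walk(data, expression.split("."))


def _walk(node: Any, parts: List[str]) -> Optional[Any]:
    if not parts:
        return node
    part, rest = parts[0], parts[1:]
    if part.endswith("[*]"):
        # Wildcard: apply the remaining path to every element of the array.
        key = part[:-3]
        items = node.get(key) if isinstance(node, dict) else None
        if not isinstance(items, list):
            return None
        matches = [_walk(item, rest) for item in items]
        matches = [m for m in matches if m is not None]
        return matches or None
    if isinstance(node, dict) and part in node:
        return _walk(node[part], rest)
    return None


def extract_fields(data: dict, fields: List[str]) -> str:
    """Concatenate every match for every field expression into one string."""
    pieces: List[str] = []
    for expression in fields:
        value = evaluate_path(data, expression)
        if value is None:
            continue  # missing fields are skipped, matching test_handles_missing_fields
        if isinstance(value, list):
            pieces.extend(str(item) for item in value)
        else:
            pieces.append(str(value))
    return "\n".join(pieces)


if __name__ == "__main__":
    body = {
        "query": "What is coffee?",
        "documents": [{"text": "doc1"}, {"text": "doc2"}],
    }
    # With request_fields: ["documents[*].text"], only the document texts
    # would reach the guardrail; the query is left out.
    print(extract_fields(body, ["documents[*].text"]))  # -> "doc1\ndoc2"
```

Joining matches into one string keeps the guardrail input simple while preserving per-field boundaries; the exact separator used internally is an assumption here, so only containment (not exact formatting) is asserted in the tests.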