[Feat] Add guardrails for pass through endpoints (#17221)

* add PassThroughGuardrailsConfig

* init JsonPathExtractor

* feat PassthroughGuardrailHandler

* feat pt guardrails

* pt guardrails

* add Pass-Through Endpoint Guardrail Translation

* add PassThroughEndpointHandler

* execute simple guardrail config and dict settings

* TestPassthroughGuardrailHandlerNormalizeConfig

* add passthrough_guardrails_config on litellm logging obj

* add LiteLLMLoggingObj to base trasaltino

* cleaner _get_guardrail_settings

* update guardrails settings

* docs pt guardrail

* docs Guardrails on Pass-Through Endpoints

* fix typing

* fix typing

* test_no_fields_set_sends_full_body

* fix typing

* Potential fix for code scanning alert no. 3834: Clear-text logging of sensitive information

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff
2025-11-27 12:06:53 -08:00
committed by GitHub
parent ef1b3f954b
commit d612d71ef4
24 changed files with 1378 additions and 34 deletions

View File

@@ -0,0 +1,214 @@
# Guardrails on Pass-Through Endpoints
## Overview
| Property | Details |
|----------|---------|
| Description | Enable guardrail execution on LiteLLM pass-through endpoints with opt-in activation and automatic inheritance from org/team/key levels |
| Supported Guardrails | All LiteLLM guardrails (Bedrock, Aporia, Lakera, etc.) |
| Default Behavior | Guardrails are **disabled** on pass-through endpoints unless explicitly enabled |
## Quick Start
### 1. Define guardrails and pass-through endpoint
```yaml showLineNumbers title="config.yaml"
guardrails:
- guardrail_name: "pii-guard"
litellm_params:
guardrail: bedrock
mode: pre_call
guardrailIdentifier: "your-guardrail-id"
guardrailVersion: "1"
general_settings:
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
headers:
Authorization: "bearer os.environ/COHERE_API_KEY"
guardrails:
pii-guard:
```
### 2. Start proxy
```bash
litellm --config config.yaml
```
### 3. Test request
```bash
curl -X POST "http://localhost:4000/v1/rerank" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-1234" \
-d '{
"model": "rerank-english-v3.0",
"query": "What is the capital of France?",
"documents": ["Paris is the capital of France."]
}'
```
---
## Opt-In Behavior
| Configuration | Behavior |
|--------------|----------|
| `guardrails` not set | No guardrails execute (default) |
| `guardrails` set | All org/team/key + pass-through guardrails execute |
When guardrails are enabled, the system collects and executes:
- Org-level guardrails
- Team-level guardrails
- Key-level guardrails
- Pass-through specific guardrails
---
## How It Works
The diagram below shows what happens when a client makes a request to `/special/rerank` - a pass-through endpoint configured with guardrails in your `config.yaml`.
When guardrails are configured on a pass-through endpoint:
1. **Pre-call guardrails** run on the request before forwarding to the target API
2. If `request_fields` is specified (e.g., `["query"]`), only those fields are sent to the guardrail. Otherwise, the entire request payload is evaluated.
3. The request is forwarded to the target API only if guardrails pass
4. **Post-call guardrails** run on the response from the target API
5. If `response_fields` is specified (e.g., `["results[*].text"]`), only those fields are evaluated. Otherwise, the entire response is checked.
:::info
If the `guardrails` block is omitted or empty in your pass-through endpoint config, the request skips the guardrail flow entirely and goes directly to the target API.
:::
```mermaid
sequenceDiagram
participant Client
box rgb(200, 220, 255) LiteLLM Proxy
participant PassThrough as Pass-through Endpoint
participant Guardrails
end
participant Target as Target API (Cohere, etc.)
Client->>PassThrough: POST /special/rerank
Note over PassThrough,Guardrails: Collect passthrough + org/team/key guardrails
PassThrough->>Guardrails: Run pre_call (request_fields or full payload)
Guardrails-->>PassThrough: ✓ Pass / ✗ Block
PassThrough->>Target: Forward request
Target-->>PassThrough: Response
PassThrough->>Guardrails: Run post_call (response_fields or full payload)
Guardrails-->>PassThrough: ✓ Pass / ✗ Block
PassThrough-->>Client: Return response (or error)
```
---
## Field-Level Targeting
Target specific JSON fields instead of the entire request/response payload.
```yaml showLineNumbers title="config.yaml"
guardrails:
- guardrail_name: "pii-detection"
litellm_params:
guardrail: bedrock
mode: pre_call
guardrailIdentifier: "pii-guard-id"
guardrailVersion: "1"
- guardrail_name: "content-moderation"
litellm_params:
guardrail: bedrock
mode: post_call
guardrailIdentifier: "content-guard-id"
guardrailVersion: "1"
general_settings:
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
headers:
Authorization: "bearer os.environ/COHERE_API_KEY"
guardrails:
pii-detection:
request_fields: ["query", "documents[*].text"]
content-moderation:
response_fields: ["results[*].text"]
```
### Field Options
| Field | Description |
|-------|-------------|
| `request_fields` | JSONPath expressions for input (pre_call) |
| `response_fields` | JSONPath expressions for output (post_call) |
| Neither specified | Guardrail runs on entire payload |
### JSONPath Examples
| Expression | Matches |
|------------|---------|
| `query` | Single field named `query` |
| `documents[*].text` | All `text` fields in `documents` array |
| `messages[*].content` | All `content` fields in `messages` array |
---
## Configuration Examples
### Single guardrail on entire payload
```yaml showLineNumbers title="config.yaml"
guardrails:
- guardrail_name: "pii-detection"
litellm_params:
guardrail: bedrock
mode: pre_call
guardrailIdentifier: "your-id"
guardrailVersion: "1"
general_settings:
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
guardrails:
pii-detection:
```
### Multiple guardrails with mixed settings
```yaml showLineNumbers title="config.yaml"
guardrails:
- guardrail_name: "pii-detection"
litellm_params:
guardrail: bedrock
mode: pre_call
guardrailIdentifier: "pii-id"
guardrailVersion: "1"
- guardrail_name: "content-moderation"
litellm_params:
guardrail: bedrock
mode: post_call
guardrailIdentifier: "content-id"
guardrailVersion: "1"
- guardrail_name: "prompt-injection"
litellm_params:
guardrail: lakera
mode: pre_call
api_key: os.environ/LAKERA_API_KEY
general_settings:
pass_through_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
guardrails:
pii-detection:
request_fields: ["input", "query"]
content-moderation:
prompt-injection:
request_fields: ["messages[*].content"]
```

View File

@@ -419,7 +419,8 @@ const sidebars = {
]
},
"pass_through/vllm",
"proxy/pass_through"
"proxy/pass_through",
"proxy/pass_through_guardrails"
]
},
"rag_ingest",

View File

@@ -374,6 +374,9 @@ class Logging(LiteLLMLoggingBaseClass):
# Init Caching related details
self.caching_details: Optional[CachingDetails] = None
# Passthrough endpoint guardrails config for field targeting
self.passthrough_guardrails_config: Optional[Dict[str, Any]] = None
self.model_call_details: Dict[str, Any] = {
"litellm_trace_id": litellm_trace_id,
"litellm_call_id": litellm_call_id,

View File

@@ -41,6 +41,7 @@ class AnthropicMessagesHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input messages by applying guardrails to text content.
@@ -145,6 +146,7 @@ class AnthropicMessagesHandler(BaseTranslation):
self,
response: "AnthropicMessagesResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response by applying guardrails to text content.

View File

@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
class BaseTranslation(ABC):
@@ -11,6 +12,7 @@ class BaseTranslation(ABC):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> Any:
pass
@@ -19,5 +21,6 @@ class BaseTranslation(ABC):
self,
response: Any,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> Any:
pass

View File

@@ -5,7 +5,7 @@ This module provides guardrail translation support for the rerank endpoint.
The handler processes only the 'query' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
@@ -34,6 +34,7 @@ class CohereRerankHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input query by applying guardrails.
@@ -68,6 +69,7 @@ class CohereRerankHandler(BaseTranslation):
self,
response: "RerankResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response - not applicable for rerank.

View File

@@ -42,6 +42,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input messages by applying guardrails to text content.
@@ -148,6 +149,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
self,
response: "ModelResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response by applying guardrails to text content.

View File

@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text completion
The handler processes the 'prompt' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
@@ -32,6 +32,7 @@ class OpenAITextCompletionHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input prompt by applying guardrails to text content.
@@ -100,6 +101,7 @@ class OpenAITextCompletionHandler(BaseTranslation):
self,
response: "TextCompletionResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response by applying guardrails to completion text.

View File

@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's image generation
The handler processes the 'prompt' parameter for guardrails.
"""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
@@ -31,6 +31,7 @@ class OpenAIImageGenerationHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input prompt by applying guardrails to text content.
@@ -72,6 +73,7 @@ class OpenAIImageGenerationHandler(BaseTranslation):
self,
response: "ImageResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response - typically not needed for image generation.

View File

@@ -56,6 +56,7 @@ class OpenAIResponsesHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input by applying guardrails to text content.
@@ -177,6 +178,7 @@ class OpenAIResponsesHandler(BaseTranslation):
self,
response: "ResponsesAPIResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output response by applying guardrails to text content.

View File

@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text-to-speech e
The handler processes the 'input' text parameter (output is audio, so no text to guardrail).
"""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
@@ -30,6 +30,7 @@ class OpenAITextToSpeechHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input text by applying guardrails.
@@ -72,6 +73,7 @@ class OpenAITextToSpeechHandler(BaseTranslation):
self,
response: "HttpxBinaryResponseContent",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output - not applicable for text-to-speech.

View File

@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's audio transcript
The handler processes the output transcribed text (input is audio, so no text to guardrail).
"""
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
@@ -30,6 +30,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation):
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process input - not applicable for audio transcription.
@@ -54,6 +55,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation):
self,
response: "TranscriptionResponse",
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional[Any] = None,
) -> Any:
"""
Process output transcription by applying guardrails to transcribed text.

View File

@@ -0,0 +1,12 @@
"""
Pass-Through Endpoint Guardrail Translation
This module exists here (under litellm/llms/) so it can be auto-discovered by
load_guardrail_translation_mappings() which scans for guardrail_translation
directories under litellm/llms/.
The main passthrough endpoint implementation is in:
litellm/proxy/pass_through_endpoints/
See guardrail_translation/README.md for more details.
"""

View File

@@ -0,0 +1,41 @@
# Pass-Through Endpoint Guardrail Translation
## Why This Exists Here
This module is located under `litellm/llms/` (instead of with the main passthrough code) because:
1. **Auto-discovery**: The `load_guardrail_translation_mappings()` function in `litellm/llms/__init__.py` scans for `guardrail_translation/` directories under `litellm/llms/`
2. **Consistency**: All other guardrail translation handlers follow this pattern (e.g., `openai/chat/guardrail_translation/`, `anthropic/chat/guardrail_translation/`)
## Main Passthrough Implementation
The main passthrough endpoint implementation is in:
```
litellm/proxy/pass_through_endpoints/
├── pass_through_endpoints.py # Core passthrough routing logic
├── passthrough_guardrails.py # Guardrail collection and field targeting
├── jsonpath_extractor.py # JSONPath field extraction utility
└── ...
```
## What This Handler Does
The `PassThroughEndpointHandler` enables guardrails to run on passthrough endpoint requests by:
1. **Field Targeting**: Extracts specific fields from the request/response using JSONPath expressions configured in `request_fields` / `response_fields`
2. **Full Payload Fallback**: If no field targeting is configured, processes the entire payload
3. **Config Access**: Uses `get_passthrough_guardrails_config()` / `set_passthrough_guardrails_config()` helpers to access the passthrough guardrails configuration stored in request metadata
## Example Config
```yaml
passthrough_endpoints:
- path: "/v1/rerank"
target: "https://api.cohere.com/v1/rerank"
guardrails:
bedrock-pre-guard:
request_fields: ["query", "documents[*].text"]
response_fields: ["results[*].text"]
```

View File

@@ -0,0 +1,15 @@
"""Pass-Through Endpoint guardrail translation handler."""
from litellm.llms.pass_through.guardrail_translation.handler import (
PassThroughEndpointHandler,
)
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.pass_through: PassThroughEndpointHandler,
}
__all__ = [
"guardrail_translation_mappings",
"PassThroughEndpointHandler",
]

View File

@@ -0,0 +1,165 @@
"""
Pass-Through Endpoint Message Handler for Unified Guardrails
This module provides a handler for passthrough endpoint requests.
It uses the field targeting configuration from litellm_logging_obj
to extract specific fields for guardrail processing.
"""
from typing import TYPE_CHECKING, Any, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.proxy._types import PassThroughGuardrailSettings
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
class PassThroughEndpointHandler(BaseTranslation):
"""
Handler for processing passthrough endpoint requests with guardrails.
Uses passthrough_guardrails_config from litellm_logging_obj
to determine which fields to extract for guardrail processing.
"""
def _get_guardrail_settings(
self,
litellm_logging_obj: Optional["LiteLLMLoggingObj"],
guardrail_name: Optional[str],
) -> Optional[PassThroughGuardrailSettings]:
"""
Get the guardrail settings for a specific guardrail from logging_obj.
"""
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
PassthroughGuardrailHandler,
)
if litellm_logging_obj is None:
return None
passthrough_config = getattr(
litellm_logging_obj, "passthrough_guardrails_config", None
)
if not passthrough_config or not guardrail_name:
return None
return PassthroughGuardrailHandler.get_settings(
passthrough_config, guardrail_name
)
def _extract_text_for_guardrail(
self,
data: dict,
field_expressions: Optional[List[str]],
) -> str:
"""
Extract text from data for guardrail processing.
If field_expressions provided, extracts only those fields.
Otherwise, returns the full payload as JSON.
"""
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import (
JsonPathExtractor,
)
if field_expressions:
text = JsonPathExtractor.extract_fields(
data=data,
jsonpath_expressions=field_expressions,
)
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: Extracted targeted fields: %s",
text[:200] if text else None,
)
return text
# Use entire payload, excluding internal fields
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
payload_to_check = {
k: v
for k, v in data.items()
if not k.startswith("_") and k not in ("metadata", "litellm_logging_obj")
}
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: Using full payload for guardrail"
)
return safe_dumps(payload_to_check)
async def process_input_messages(
self,
data: dict,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> Any:
"""
Process input by applying guardrails to targeted fields or full payload.
"""
guardrail_name = guardrail_to_apply.guardrail_name
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: Processing input for guardrail=%s",
guardrail_name,
)
# Get field targeting settings for this guardrail
settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name)
field_expressions = settings.request_fields if settings else None
# Extract text to check
text_to_check = self._extract_text_for_guardrail(data, field_expressions)
if not text_to_check:
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: No text to check, skipping guardrail"
)
return data
# Apply guardrail
await guardrail_to_apply.apply_guardrail(
text=text_to_check,
request_data=data,
)
return data
async def process_output_response(
self,
response: Any,
guardrail_to_apply: "CustomGuardrail",
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> Any:
"""
Process output response by applying guardrails to targeted fields.
"""
if not isinstance(response, dict):
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: Response is not a dict, skipping"
)
return response
guardrail_name = guardrail_to_apply.guardrail_name
verbose_proxy_logger.debug(
"PassThroughEndpointHandler: Processing output for guardrail=%s",
guardrail_name,
)
# Get field targeting settings for this guardrail
settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name)
field_expressions = settings.response_fields if settings else None
# Extract text to check
text_to_check = self._extract_text_for_guardrail(response, field_expressions)
if not text_to_check:
return response
# Apply guardrail
await guardrail_to_apply.apply_guardrail(
text=text_to_check,
request_data=response,
)
return response

View File

@@ -1693,27 +1693,26 @@ class DynamoDBArgs(LiteLLMPydanticObjectBase):
assume_role_aws_session_name: Optional[str] = None
class PassThroughGuardrailConfig(LiteLLMPydanticObjectBase):
class PassThroughGuardrailSettings(LiteLLMPydanticObjectBase):
"""
Configuration for guardrails on passthrough endpoints.
Settings for a specific guardrail on a passthrough endpoint.
Passthrough endpoints are opt-in only for guardrails. Guardrails configured at
org/team/key levels will NOT execute unless explicitly enabled here.
Allows field-level targeting for guardrail execution.
"""
enabled: bool = Field(
default=False,
description="Whether to execute guardrails for this passthrough endpoint. When True, all org/team/key level guardrails will execute along with any passthrough-specific guardrails. When False (default), NO guardrails execute.",
)
specific: Optional[List[str]] = Field(
request_fields: Optional[List[str]] = Field(
default=None,
description="Optional list of guardrail names that are specific to this passthrough endpoint. These will execute in addition to org/team/key level guardrails when enabled=True.",
description="JSONPath expressions for input field targeting (pre_call). Examples: 'query', 'documents[*].text', 'messages[*].content'. If not specified, guardrail runs on entire request payload.",
)
target_fields: Optional[List[str]] = Field(
response_fields: Optional[List[str]] = Field(
default=None,
description="Optional list of JSON paths to target specific fields for guardrail execution. Examples: 'messages[*].content', 'input', 'messages[?(@.role=='user')].content'. If not specified, guardrails execute on entire payload.",
description="JSONPath expressions for output field targeting (post_call). Examples: 'results[*].text', 'output'. If not specified, guardrail runs on entire response payload.",
)
# Type alias for the guardrails dict: guardrail_name -> settings (or None for defaults)
PassThroughGuardrailsConfig = Dict[str, Optional[PassThroughGuardrailSettings]]
class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase):
id: Optional[str] = Field(
default=None,
@@ -1739,9 +1738,9 @@ class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase):
default=False,
description="Whether authentication is required for the pass-through endpoint. If True, requests to the endpoint will require a valid LiteLLM API key.",
)
guardrails: Optional[PassThroughGuardrailConfig] = Field(
guardrails: Optional[PassThroughGuardrailsConfig] = Field(
default=None,
description="Guardrail configuration for this passthrough endpoint. When enabled, org/team/key level guardrails will execute along with any passthrough-specific guardrails. Defaults to disabled (no guardrails execute).",
description="Guardrails configuration for this passthrough endpoint. Dict keys are guardrail names, values are optional settings for field targeting. When set, all org/team/key level guardrails will also execute. Defaults to None (no guardrails execute).",
)

View File

@@ -86,6 +86,7 @@ class UnifiedLLMGuardrails(CustomLogger):
data = await endpoint_translation.process_input_messages(
data=data,
guardrail_to_apply=guardrail_to_apply,
litellm_logging_obj=data.get("litellm_logging_obj"),
)
# Add guardrail to applied guardrails header
@@ -148,6 +149,7 @@ class UnifiedLLMGuardrails(CustomLogger):
response = await endpoint_translation.process_output_response(
response=response, # type: ignore
guardrail_to_apply=guardrail_to_apply,
litellm_logging_obj=data.get("litellm_logging_obj"),
)
# Add guardrail to applied guardrails header
add_guardrail_to_applied_guardrails_header(

View File

@@ -0,0 +1,95 @@
"""
JSONPath Extractor Module
Extracts field values from data using simple JSONPath-like expressions.
"""
from typing import Any, List, Union
from litellm._logging import verbose_proxy_logger
class JsonPathExtractor:
"""Extracts field values from data using JSONPath-like expressions."""
@staticmethod
def extract_fields(
data: dict,
jsonpath_expressions: List[str],
) -> str:
"""
Extract field values from data using JSONPath-like expressions.
Supports simple expressions like:
- "query" -> data["query"]
- "documents[*].text" -> all text fields from documents array
- "messages[*].content" -> all content fields from messages array
Returns concatenated string of all extracted values.
"""
extracted_values: List[str] = []
for expr in jsonpath_expressions:
try:
value = JsonPathExtractor.evaluate(data, expr)
if value:
if isinstance(value, list):
extracted_values.extend([str(v) for v in value if v])
else:
extracted_values.append(str(value))
except Exception as e:
verbose_proxy_logger.debug(
"Failed to extract field %s: %s", expr, str(e)
)
return "\n".join(extracted_values)
@staticmethod
def evaluate(data: dict, expr: str) -> Union[str, List[str], None]:
"""
Evaluate a simple JSONPath-like expression.
Supports:
- Simple key: "query" -> data["query"]
- Nested key: "foo.bar" -> data["foo"]["bar"]
- Array wildcard: "items[*].text" -> [item["text"] for item in data["items"]]
"""
if not expr or not data:
return None
parts = expr.replace("[*]", ".[*]").split(".")
current: Any = data
for i, part in enumerate(parts):
if current is None:
return None
if part == "[*]":
# Wildcard - current should be a list
if not isinstance(current, list):
return None
# Get remaining path
remaining_path = ".".join(parts[i + 1:])
if not remaining_path:
return current
# Recursively evaluate remaining path for each item
results = []
for item in current:
if isinstance(item, dict):
result = JsonPathExtractor.evaluate(item, remaining_path)
if result:
if isinstance(result, list):
results.extend(result)
else:
results.append(result)
return results if results else None
elif isinstance(current, dict):
current = current.get(part)
else:
return None
return current

View File

@@ -599,6 +599,7 @@ async def pass_through_request( # noqa: PLR0915
stream: Optional[bool] = None,
cost_per_request: Optional[float] = None,
custom_llm_provider: Optional[str] = None,
guardrails_config: Optional[dict] = None,
):
"""
Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called
@@ -614,8 +615,13 @@ async def pass_through_request( # noqa: PLR0915
query_params: The query params
stream: Whether to stream the response
cost_per_request: Optional field - cost per request to the target endpoint
custom_llm_provider: Optional field - custom LLM provider for the endpoint
guardrails_config: Optional field - guardrails configuration for passthrough endpoint
"""
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
PassthroughGuardrailHandler,
)
from litellm.proxy.proxy_server import proxy_logging_obj
#########################################################
@@ -664,6 +670,45 @@ async def pass_through_request( # noqa: PLR0915
)
)
### COLLECT GUARDRAILS FOR PASSTHROUGH ENDPOINT ###
# Passthrough endpoints are opt-in only for guardrails
# When enabled, collect guardrails from org/team/key levels + passthrough-specific
guardrails_to_run = PassthroughGuardrailHandler.collect_guardrails(
user_api_key_dict=user_api_key_dict,
passthrough_guardrails_config=guardrails_config,
)
# Add guardrails to metadata if any should run
if guardrails_to_run and len(guardrails_to_run) > 0:
if _parsed_body is None:
_parsed_body = {}
if "metadata" not in _parsed_body:
_parsed_body["metadata"] = {}
_parsed_body["metadata"]["guardrails"] = guardrails_to_run
verbose_proxy_logger.debug(
f"Added guardrails to passthrough request metadata: {guardrails_to_run}"
)
## LOGGING OBJECT ## - initialize before pre_call_hook so guardrails can access it
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
litellm_call_id=litellm_call_id,
function_id="1245",
)
# Store passthrough guardrails config on logging_obj for field targeting
logging_obj.passthrough_guardrails_config = guardrails_config
# Store logging_obj in data so guardrails can access it
if _parsed_body is None:
_parsed_body = {}
_parsed_body["litellm_logging_obj"] = logging_obj
### CALL HOOKS ### - modify incoming data / reject request before calling the model
_parsed_body = await proxy_logging_obj.pre_call_hook(
user_api_key_dict=user_api_key_dict,
@@ -675,18 +720,6 @@ async def pass_through_request( # noqa: PLR0915
params={"timeout": 600},
)
async_client = async_client_obj.client
# create logging object
start_time = datetime.now()
logging_obj = Logging(
model="unknown",
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
stream=False,
call_type="pass_through_endpoint",
start_time=start_time,
litellm_call_id=litellm_call_id,
function_id="1245",
)
passthrough_logging_payload = PassthroughStandardLoggingPayload(
url=str(url),
request_body=_parsed_body,
@@ -1011,6 +1044,7 @@ def create_pass_through_route(
custom_llm_provider: Optional[str] = None,
is_streaming_request: Optional[bool] = False,
query_params: Optional[dict] = None,
guardrails: Optional[Dict[str, Any]] = None,
):
# check if target is an adapter.py or a url
from litellm._uuid import uuid
@@ -1079,6 +1113,7 @@ def create_pass_through_route(
"forward_headers": _forward_headers,
"merge_query_params": _merge_query_params,
"cost_per_request": cost_per_request,
"guardrails": None,
}
if passthrough_params is not None:
@@ -1096,6 +1131,7 @@ def create_pass_through_route(
param_cost_per_request = target_params.get(
"cost_per_request", cost_per_request
)
param_guardrails = target_params.get("guardrails", None)
# Construct the full target URL with subpath if needed
full_target = (
@@ -1135,6 +1171,7 @@ def create_pass_through_route(
custom_body=final_custom_body,
cost_per_request=cast(Optional[float], param_cost_per_request),
custom_llm_provider=custom_llm_provider,
guardrails_config=param_guardrails,
)
return endpoint_func
@@ -1769,6 +1806,7 @@ class InitPassThroughEndpointHelpers:
dependencies: Optional[List],
cost_per_request: Optional[float],
endpoint_id: str,
guardrails: Optional[dict] = None,
):
"""Add exact path route for pass-through endpoint"""
route_key = f"{endpoint_id}:exact:{path}"
@@ -1799,6 +1837,7 @@ class InitPassThroughEndpointHelpers:
merge_query_params,
dependencies,
cost_per_request=cost_per_request,
guardrails=guardrails,
),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
dependencies=dependencies,
@@ -1817,6 +1856,7 @@ class InitPassThroughEndpointHelpers:
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}
@@ -1831,6 +1871,7 @@ class InitPassThroughEndpointHelpers:
dependencies: Optional[List],
cost_per_request: Optional[float],
endpoint_id: str,
guardrails: Optional[dict] = None,
):
"""Add wildcard route for sub-paths"""
wildcard_path = f"{path}/{{subpath:path}}"
@@ -1863,6 +1904,7 @@ class InitPassThroughEndpointHelpers:
dependencies,
include_subpath=True,
cost_per_request=cost_per_request,
guardrails=guardrails,
),
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
dependencies=dependencies,
@@ -1881,6 +1923,7 @@ class InitPassThroughEndpointHelpers:
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}
@@ -2057,6 +2100,9 @@ async def initialize_pass_through_endpoints(
if _target is None:
continue
# Get guardrails config if present
_guardrails = endpoint.get("guardrails", None)
# Add exact path route
verbose_proxy_logger.debug(
"Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id
@@ -2071,6 +2117,7 @@ async def initialize_pass_through_endpoints(
dependencies=_dependencies,
cost_per_request=endpoint.get("cost_per_request", None),
endpoint_id=endpoint_id,
guardrails=_guardrails,
)
visited_endpoints.add(f"{endpoint_id}:exact:{_path}")
@@ -2087,6 +2134,7 @@ async def initialize_pass_through_endpoints(
dependencies=_dependencies,
cost_per_request=endpoint.get("cost_per_request", None),
endpoint_id=endpoint_id,
guardrails=_guardrails,
)
visited_endpoints.add(f"{endpoint_id}:subpath:{_path}")

View File

@@ -0,0 +1,333 @@
"""
Passthrough Guardrails Helper Module
Handles guardrail execution for passthrough endpoints with:
- Opt-in model (guardrails only run when explicitly configured)
- Field-level targeting using JSONPath expressions
- Automatic inheritance from org/team/key levels when enabled
"""
from typing import Any, Dict, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
PassThroughGuardrailsConfig,
PassThroughGuardrailSettings,
UserAPIKeyAuth,
)
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
# Type for raw guardrails config input (before normalization)
# Can be a list of names or a dict with settings
PassThroughGuardrailsConfigInput = Union[
List[str], # Simple list: ["guard-1", "guard-2"]
PassThroughGuardrailsConfig, # Dict: {"guard-1": {"request_fields": [...]}}
]
class PassthroughGuardrailHandler:
"""
Handles guardrail execution for passthrough endpoints.
Passthrough endpoints use an opt-in model for guardrails:
- Guardrails only run when explicitly configured on the endpoint
- Supports field-level targeting using JSONPath expressions
- Automatically inherits org/team/key level guardrails when enabled
Guardrails can be specified as:
- List format (simple): ["guardrail-1", "guardrail-2"]
- Dict format (with settings): {"guardrail-1": {"request_fields": ["query"]}}
"""
@staticmethod
def normalize_config(
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
) -> Optional[PassThroughGuardrailsConfig]:
"""
Normalize guardrails config to dict format.
Accepts:
- List of guardrail names: ["g1", "g2"] -> {"g1": None, "g2": None}
- Dict with settings: {"g1": {"request_fields": [...]}}
- None: returns None
"""
if guardrails_config is None:
return None
# Already a dict - return as-is
if isinstance(guardrails_config, dict):
return guardrails_config
# List of guardrail names - convert to dict
if isinstance(guardrails_config, list):
return {name: None for name in guardrails_config}
verbose_proxy_logger.debug(
"Passthrough guardrails config is not a dict or list, got: %s",
type(guardrails_config),
)
return None
@staticmethod
def is_enabled(
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
) -> bool:
"""
Check if guardrails are enabled for a passthrough endpoint.
Passthrough endpoints are opt-in only - guardrails only run when
the guardrails config is set with at least one guardrail.
"""
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
if normalized is None:
return False
return len(normalized) > 0
@staticmethod
def get_guardrail_names(
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
) -> List[str]:
"""Get the list of guardrail names configured for a passthrough endpoint."""
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
if normalized is None:
return []
return list(normalized.keys())
@staticmethod
def get_settings(
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
guardrail_name: str,
) -> Optional[PassThroughGuardrailSettings]:
"""Get settings for a specific guardrail from the passthrough config."""
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
if normalized is None:
return None
settings = normalized.get(guardrail_name)
if settings is None:
return None
if isinstance(settings, dict):
return PassThroughGuardrailSettings(**settings)
return settings
@staticmethod
def prepare_input(
request_data: dict,
guardrail_settings: Optional[PassThroughGuardrailSettings],
) -> str:
"""
Prepare input text for guardrail execution based on field targeting settings.
If request_fields is specified, extracts only those fields.
Otherwise, uses the entire request payload as text.
"""
if guardrail_settings is None or guardrail_settings.request_fields is None:
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
return safe_dumps(request_data)
return JsonPathExtractor.extract_fields(
data=request_data,
jsonpath_expressions=guardrail_settings.request_fields,
)
@staticmethod
def prepare_output(
response_data: dict,
guardrail_settings: Optional[PassThroughGuardrailSettings],
) -> str:
"""
Prepare output text for guardrail execution based on field targeting settings.
If response_fields is specified, extracts only those fields.
Otherwise, uses the entire response payload as text.
"""
if guardrail_settings is None or guardrail_settings.response_fields is None:
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
return safe_dumps(response_data)
return JsonPathExtractor.extract_fields(
data=response_data,
jsonpath_expressions=guardrail_settings.response_fields,
)
@staticmethod
async def execute(
request_data: dict,
user_api_key_dict: UserAPIKeyAuth,
guardrails_config: Optional[PassThroughGuardrailsConfig],
event_type: str = "pre_call",
) -> dict:
"""
Execute guardrails for a passthrough endpoint.
This is the main entry point for passthrough guardrail execution.
Args:
request_data: The request payload
user_api_key_dict: User API key authentication info
guardrails_config: Passthrough-specific guardrails configuration
event_type: "pre_call" for request, "post_call" for response
Returns:
The potentially modified request_data
Raises:
HTTPException if a guardrail blocks the request
"""
if not PassthroughGuardrailHandler.is_enabled(guardrails_config):
verbose_proxy_logger.debug(
"Passthrough guardrails not enabled, skipping guardrail execution"
)
return request_data
guardrail_names = PassthroughGuardrailHandler.get_guardrail_names(
guardrails_config
)
verbose_proxy_logger.debug(
"Executing passthrough guardrails: %s", guardrail_names
)
# Add to request metadata so guardrails know which to run
from litellm.proxy.pass_through_endpoints.passthrough_context import (
set_passthrough_guardrails_config,
)
if "metadata" not in request_data:
request_data["metadata"] = {}
# Set guardrails in metadata using dict format for compatibility
request_data["metadata"]["guardrails"] = {
name: True for name in guardrail_names
}
# Store passthrough guardrails config in request-scoped context
set_passthrough_guardrails_config(guardrails_config)
return request_data
@staticmethod
def collect_guardrails(
user_api_key_dict: UserAPIKeyAuth,
passthrough_guardrails_config: Optional[PassThroughGuardrailsConfigInput],
) -> Optional[Dict[str, bool]]:
"""
Collect guardrails for a passthrough endpoint.
Passthrough endpoints are opt-in only for guardrails. Guardrails only run when
the guardrails config is set with at least one guardrail.
Accepts both list and dict formats:
- List: ["guardrail-1", "guardrail-2"]
- Dict: {"guardrail-1": {"request_fields": [...]}}
When enabled, this function collects:
- Passthrough-specific guardrails from the config
- Org/team/key level guardrails (automatic inheritance when passthrough is enabled)
Args:
user_api_key_dict: User API key authentication info
passthrough_guardrails_config: List or Dict of guardrail names/settings
Returns:
Dict of guardrail names to run (format: {guardrail_name: True}), or None
"""
from litellm.proxy.litellm_pre_call_utils import (
_add_guardrails_from_key_or_team_metadata,
)
# Normalize config to dict format (handles both list and dict)
normalized_config = PassthroughGuardrailHandler.normalize_config(
passthrough_guardrails_config
)
if normalized_config is None:
verbose_proxy_logger.debug(
"Passthrough guardrails not configured, skipping guardrail collection"
)
return None
if len(normalized_config) == 0:
verbose_proxy_logger.debug(
"Passthrough guardrails config is empty, skipping"
)
return None
# Passthrough is enabled - collect guardrails
guardrails_to_run: Dict[str, bool] = {}
# Add passthrough-specific guardrails
for guardrail_name in normalized_config.keys():
guardrails_to_run[guardrail_name] = True
verbose_proxy_logger.debug(
"Added passthrough-specific guardrail"
)
# Add org/team/key level guardrails using shared helper
temp_data: Dict[str, Any] = {"metadata": {}}
_add_guardrails_from_key_or_team_metadata(
key_metadata=user_api_key_dict.metadata,
team_metadata=user_api_key_dict.team_metadata,
data=temp_data,
metadata_variable_name="metadata",
)
# Merge inherited guardrails into guardrails_to_run
inherited_guardrails = temp_data["metadata"].get("guardrails", [])
for guardrail_name in inherited_guardrails:
if guardrail_name not in guardrails_to_run:
guardrails_to_run[guardrail_name] = True
verbose_proxy_logger.debug(
"Added inherited guardrail (key/team level)"
)
verbose_proxy_logger.debug(
"Collected total guardrails for passthrough endpoint: %d",
len(guardrails_to_run),
)
return guardrails_to_run if guardrails_to_run else None
@staticmethod
def get_field_targeted_text(
data: dict,
guardrail_name: str,
is_request: bool = True,
) -> Optional[str]:
"""
Get the text to check for a guardrail, respecting field targeting settings.
Called by guardrail hooks to get the appropriate text based on
passthrough field targeting configuration.
Args:
data: The request/response data dict
guardrail_name: Name of the guardrail being executed
is_request: True for request (pre_call), False for response (post_call)
Returns:
The text to check, or None to use default behavior
"""
from litellm.proxy.pass_through_endpoints.passthrough_context import (
get_passthrough_guardrails_config,
)
passthrough_config = get_passthrough_guardrails_config()
if passthrough_config is None:
return None
settings = PassthroughGuardrailHandler.get_settings(
passthrough_config, guardrail_name
)
if settings is None:
return None
if is_request:
if settings.request_fields:
return JsonPathExtractor.extract_fields(data, settings.request_fields)
else:
if settings.response_fields:
return JsonPathExtractor.extract_fields(data, settings.response_fields)
return None

View File

@@ -3,7 +3,14 @@ model_list:
litellm_params:
model: bedrock/openai/arn:aws:bedrock:us-east-1:046319184608:imported-model/0m2lasirsp6z
guardrails:
- guardrail_name: "bedrock-pre-guard"
litellm_params:
guardrail: bedrock
mode: "pre_call"
guardrailIdentifier: ff6ujrregl1q
guardrailVersion: "DRAFT"
# like MCPs/vector stores
search_tools:
@@ -40,6 +47,14 @@ litellm_settings:
general_settings:
store_prompts_in_spend_logs: True
pass_through_endpoints:
- path: "/special/rerank"
target: "https://api.cohere.com/v1/rerank"
headers:
Authorization: "Bearer os.environ/COHERE_API_KEY"
guardrails:
bedrock-pre-guard:
request_fields: ["documents[*].text"]
vector_store_registry:

View File

@@ -0,0 +1,119 @@
"""
Test passthrough guardrails field-level targeting.
Tests that request_fields and response_fields correctly extract
and send only specified fields to the guardrail.
"""
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from litellm.proxy._types import PassThroughGuardrailSettings
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
PassthroughGuardrailHandler,
)
def test_no_fields_set_sends_full_body():
"""
Test that when no request_fields are set, the entire request body
is JSON dumped and sent to the guardrail.
"""
request_data = {
"model": "rerank-english-v3.0",
"query": "What is coffee?",
"documents": [
{"text": "Paris is the capital of France."},
{"text": "Coffee is a brewed drink."}
]
}
# No guardrail settings means full body
result = PassthroughGuardrailHandler.prepare_input(
request_data=request_data,
guardrail_settings=None
)
# Result should be JSON string of full request
assert isinstance(result, str)
result_dict = json.loads(result)
# Should contain all fields
assert "model" in result_dict
assert "query" in result_dict
assert "documents" in result_dict
assert result_dict["query"] == "What is coffee?"
assert len(result_dict["documents"]) == 2
def test_request_fields_query_only():
"""
Test that when request_fields is set to ["query"], only the query field
is extracted and sent to the guardrail.
"""
request_data = {
"model": "rerank-english-v3.0",
"query": "What is coffee?",
"documents": [
{"text": "Paris is the capital of France."},
{"text": "Coffee is a brewed drink."}
]
}
# Set request_fields to only extract query
guardrail_settings = PassThroughGuardrailSettings(
request_fields=["query"]
)
result = PassthroughGuardrailHandler.prepare_input(
request_data=request_data,
guardrail_settings=guardrail_settings
)
# Result should only contain query
assert isinstance(result, str)
assert "What is coffee?" in result
# Should NOT contain documents
assert "Paris is the capital" not in result
assert "Coffee is a brewed drink" not in result
def test_request_fields_documents_wildcard():
"""
Test that when request_fields is set to ["documents[*]"], only the documents
array is extracted and sent to the guardrail.
"""
request_data = {
"model": "rerank-english-v3.0",
"query": "What is coffee?",
"documents": [
{"text": "Paris is the capital of France."},
{"text": "Coffee is a brewed drink."}
]
}
# Set request_fields to extract documents array
guardrail_settings = PassThroughGuardrailSettings(
request_fields=["documents[*]"]
)
result = PassthroughGuardrailHandler.prepare_input(
request_data=request_data,
guardrail_settings=guardrail_settings
)
# Result should contain documents
assert isinstance(result, str)
assert "Paris is the capital" in result
assert "Coffee is a brewed drink" in result
# Should NOT contain query
assert "What is coffee?" not in result

View File

@@ -0,0 +1,263 @@
"""
Unit tests for passthrough guardrails functionality.
Tests the opt-in guardrail execution model for passthrough endpoints:
- Guardrails only run when explicitly configured
- Field-level targeting with JSONPath expressions
- Automatic inheritance from org/team/key levels when enabled
"""
import pytest
from litellm.proxy._types import PassThroughGuardrailSettings
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
PassthroughGuardrailHandler,
)
class TestPassthroughGuardrailHandlerIsEnabled:
"""Tests for PassthroughGuardrailHandler.is_enabled method."""
def test_returns_false_when_config_is_none(self):
"""Guardrails should be disabled when config is None."""
result = PassthroughGuardrailHandler.is_enabled(None)
assert result is False
def test_returns_false_when_config_is_empty_dict(self):
"""Guardrails should be disabled when config is empty dict."""
result = PassthroughGuardrailHandler.is_enabled({})
assert result is False
def test_returns_true_when_config_is_list(self):
"""Guardrails should be enabled when config is a list of names."""
result = PassthroughGuardrailHandler.is_enabled(["pii-detection"])
assert result is True
def test_returns_false_when_config_is_invalid_type(self):
"""Guardrails should be disabled when config is not a dict or list."""
result = PassthroughGuardrailHandler.is_enabled("pii-detection") # type: ignore
assert result is False
def test_returns_true_when_config_has_guardrails(self):
"""Guardrails should be enabled when config has at least one guardrail."""
config = {"pii-detection": None}
result = PassthroughGuardrailHandler.is_enabled(config)
assert result is True
def test_returns_true_with_multiple_guardrails(self):
"""Guardrails should be enabled with multiple guardrails configured."""
config = {
"pii-detection": None,
"content-moderation": {"request_fields": ["input"]},
}
result = PassthroughGuardrailHandler.is_enabled(config)
assert result is True
class TestPassthroughGuardrailHandlerGetGuardrailNames:
"""Tests for PassthroughGuardrailHandler.get_guardrail_names method."""
def test_returns_empty_list_when_disabled(self):
"""Should return empty list when guardrails are disabled."""
result = PassthroughGuardrailHandler.get_guardrail_names(None)
assert result == []
def test_returns_guardrail_names(self):
"""Should return list of guardrail names from config."""
config = {
"pii-detection": None,
"content-moderation": {"request_fields": ["input"]},
}
result = PassthroughGuardrailHandler.get_guardrail_names(config)
assert set(result) == {"pii-detection", "content-moderation"}
class TestPassthroughGuardrailHandlerNormalizeConfig:
"""Tests for PassthroughGuardrailHandler.normalize_config method."""
def test_normalizes_list_to_dict(self):
"""List of guardrail names should be converted to dict with None values."""
config = ["pii-detection", "content-moderation"]
result = PassthroughGuardrailHandler.normalize_config(config)
assert result == {"pii-detection": None, "content-moderation": None}
def test_returns_dict_unchanged(self):
"""Dict config should be returned as-is."""
config = {"pii-detection": {"request_fields": ["query"]}}
result = PassthroughGuardrailHandler.normalize_config(config)
assert result == config
class TestPassthroughGuardrailHandlerGetSettings:
"""Tests for PassthroughGuardrailHandler.get_settings method."""
def test_returns_none_when_config_is_none(self):
"""Should return None when config is None."""
result = PassthroughGuardrailHandler.get_settings(None, "pii-detection")
assert result is None
def test_returns_none_when_guardrail_not_in_config(self):
"""Should return None when guardrail is not in config."""
config = {"pii-detection": None}
result = PassthroughGuardrailHandler.get_settings(config, "content-moderation")
assert result is None
def test_returns_none_when_settings_is_none(self):
"""Should return None when guardrail has no settings."""
config = {"pii-detection": None}
result = PassthroughGuardrailHandler.get_settings(config, "pii-detection")
assert result is None
def test_returns_settings_object(self):
"""Should return PassThroughGuardrailSettings when settings are provided."""
config = {
"pii-detection": {
"request_fields": ["input", "query"],
"response_fields": ["output"],
}
}
result = PassthroughGuardrailHandler.get_settings(config, "pii-detection")
assert result is not None
assert result.request_fields == ["input", "query"]
assert result.response_fields == ["output"]
class TestJsonPathExtractorEvaluate:
"""Tests for JsonPathExtractor.evaluate method."""
def test_simple_key(self):
"""Should extract simple key from dict."""
data = {"query": "test query", "other": "value"}
result = JsonPathExtractor.evaluate(data, "query")
assert result == "test query"
def test_nested_key(self):
"""Should extract nested key from dict."""
data = {"foo": {"bar": "nested value"}}
result = JsonPathExtractor.evaluate(data, "foo.bar")
assert result == "nested value"
def test_array_wildcard(self):
"""Should extract values from array using wildcard."""
data = {"items": [{"text": "item1"}, {"text": "item2"}, {"text": "item3"}]}
result = JsonPathExtractor.evaluate(data, "items[*].text")
assert result == ["item1", "item2", "item3"]
def test_messages_content(self):
"""Should extract content from messages array."""
data = {
"messages": [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
]
}
result = JsonPathExtractor.evaluate(data, "messages[*].content")
assert result == ["You are helpful", "Hello"]
def test_missing_key_returns_none(self):
"""Should return None for missing key."""
data = {"query": "test"}
result = JsonPathExtractor.evaluate(data, "missing")
assert result is None
def test_empty_data_returns_none(self):
"""Should return None for empty data."""
result = JsonPathExtractor.evaluate({}, "query")
assert result is None
def test_empty_expression_returns_none(self):
"""Should return None for empty expression."""
data = {"query": "test"}
result = JsonPathExtractor.evaluate(data, "")
assert result is None
class TestJsonPathExtractorExtractFields:
"""Tests for JsonPathExtractor.extract_fields method."""
def test_extracts_multiple_fields(self):
"""Should extract and concatenate multiple fields."""
data = {"query": "search query", "input": "additional input"}
result = JsonPathExtractor.extract_fields(data, ["query", "input"])
assert "search query" in result
assert "additional input" in result
def test_extracts_array_fields(self):
"""Should extract and concatenate array fields."""
data = {"documents": [{"text": "doc1"}, {"text": "doc2"}]}
result = JsonPathExtractor.extract_fields(data, ["documents[*].text"])
assert "doc1" in result
assert "doc2" in result
def test_handles_missing_fields(self):
"""Should handle missing fields gracefully."""
data = {"query": "test"}
result = JsonPathExtractor.extract_fields(data, ["query", "missing"])
assert result == "test"
def test_empty_fields_returns_empty_string(self):
"""Should return empty string for empty fields list."""
data = {"query": "test"}
result = JsonPathExtractor.extract_fields(data, [])
assert result == ""
class TestPassthroughGuardrailHandlerPrepareInput:
"""Tests for PassthroughGuardrailHandler.prepare_input method."""
def test_returns_full_payload_when_no_settings(self):
"""Should return full JSON payload when no settings provided."""
data = {"query": "test", "input": "value"}
result = PassthroughGuardrailHandler.prepare_input(data, None)
assert "query" in result
assert "input" in result
def test_returns_full_payload_when_no_request_fields(self):
"""Should return full JSON payload when request_fields not set."""
data = {"query": "test", "input": "value"}
settings = PassThroughGuardrailSettings(response_fields=["output"])
result = PassthroughGuardrailHandler.prepare_input(data, settings)
assert "query" in result
assert "input" in result
def test_returns_targeted_fields(self):
"""Should return only targeted fields when request_fields set."""
data = {"query": "targeted", "input": "also targeted", "other": "ignored"}
settings = PassThroughGuardrailSettings(request_fields=["query", "input"])
result = PassthroughGuardrailHandler.prepare_input(data, settings)
assert "targeted" in result
assert "also targeted" in result
assert "ignored" not in result
class TestPassthroughGuardrailHandlerPrepareOutput:
"""Tests for PassthroughGuardrailHandler.prepare_output method."""
def test_returns_full_payload_when_no_settings(self):
"""Should return full JSON payload when no settings provided."""
data = {"results": [{"text": "result1"}], "output": "value"}
result = PassthroughGuardrailHandler.prepare_output(data, None)
assert "results" in result
assert "output" in result
def test_returns_full_payload_when_no_response_fields(self):
"""Should return full JSON payload when response_fields not set."""
data = {"results": [{"text": "result1"}], "output": "value"}
settings = PassThroughGuardrailSettings(request_fields=["input"])
result = PassthroughGuardrailHandler.prepare_output(data, settings)
assert "results" in result
assert "output" in result
def test_returns_targeted_fields(self):
"""Should return only targeted fields when response_fields set."""
data = {
"results": [{"text": "targeted1"}, {"text": "targeted2"}],
"other": "ignored",
}
settings = PassThroughGuardrailSettings(response_fields=["results[*].text"])
result = PassthroughGuardrailHandler.prepare_output(data, settings)
assert "targeted1" in result
assert "targeted2" in result
assert "ignored" not in result