mirror of
https://github.com/BerriAI/litellm.git
synced 2025-12-06 11:33:26 +08:00
[Feat] Add guardrails for pass through endpoints (#17221)
* add PassThroughGuardrailsConfig
* init JsonPathExtractor
* feat PassthroughGuardrailHandler
* feat pt guardrails
* pt guardrails
* add Pass-Through Endpoint Guardrail Translation
* add PassThroughEndpointHandler
* execute simple guardrail config and dict settings
* TestPassthroughGuardrailHandlerNormalizeConfig
* add passthrough_guardrails_config on litellm logging obj
* add LiteLLMLoggingObj to base translation
* cleaner _get_guardrail_settings
* update guardrails settings
* docs pt guardrail
* docs Guardrails on Pass-Through Endpoints
* fix typing
* fix typing
* test_no_fields_set_sends_full_body
* fix typing
* Potential fix for code scanning alert no. 3834: Clear-text logging of sensitive information

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
214
docs/my-website/docs/proxy/pass_through_guardrails.md
Normal file
@@ -0,0 +1,214 @@
# Guardrails on Pass-Through Endpoints

## Overview

| Property | Details |
|----------|---------|
| Description | Enable guardrail execution on LiteLLM pass-through endpoints with opt-in activation and automatic inheritance from org/team/key levels |
| Supported Guardrails | All LiteLLM guardrails (Bedrock, Aporia, Lakera, etc.) |
| Default Behavior | Guardrails are **disabled** on pass-through endpoints unless explicitly enabled |

## Quick Start

### 1. Define guardrails and pass-through endpoint

```yaml showLineNumbers title="config.yaml"
guardrails:
  - guardrail_name: "pii-guard"
    litellm_params:
      guardrail: bedrock
      mode: pre_call
      guardrailIdentifier: "your-guardrail-id"
      guardrailVersion: "1"

general_settings:
  pass_through_endpoints:
    - path: "/v1/rerank"
      target: "https://api.cohere.com/v1/rerank"
      headers:
        Authorization: "bearer os.environ/COHERE_API_KEY"
      guardrails:
        pii-guard:
```

### 2. Start proxy

```bash
litellm --config config.yaml
```

### 3. Test request

```bash
curl -X POST "http://localhost:4000/v1/rerank" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-1234" \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "What is the capital of France?",
    "documents": ["Paris is the capital of France."]
  }'
```

---

## Opt-In Behavior

| Configuration | Behavior |
|--------------|----------|
| `guardrails` not set | No guardrails execute (default) |
| `guardrails` set | All org/team/key + pass-through guardrails execute |

When guardrails are enabled, the system collects and executes:

- Org-level guardrails
- Team-level guardrails
- Key-level guardrails (see the example after this list)
- Pass-through specific guardrails
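For example, a guardrail attached to a virtual key is inherited automatically once the pass-through endpoint opts in. A minimal sketch, assuming the proxy from the Quick Start above and that `/key/generate` accepts a `guardrails` list (as in LiteLLM's key-level guardrail docs):

```python
import requests

# Create a virtual key that carries its own guardrail. Requests made with this
# key to a guardrail-enabled pass-through endpoint also run "pii-guard".
resp = requests.post(
    "http://localhost:4000/key/generate",
    headers={"Authorization": "Bearer sk-1234"},
    json={"guardrails": ["pii-guard"]},  # assumes "pii-guard" is defined in config.yaml
)
virtual_key = resp.json()["key"]
print(virtual_key)  # use this key in the Authorization header of /v1/rerank calls
```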
---

## How It Works

The diagram below shows what happens when a client makes a request to `/special/rerank` - a pass-through endpoint configured with guardrails in your `config.yaml`.

When guardrails are configured on a pass-through endpoint:

1. **Pre-call guardrails** run on the request before forwarding to the target API
2. If `request_fields` is specified (e.g., `["query"]`), only those fields are sent to the guardrail. Otherwise, the entire request payload is evaluated.
3. The request is forwarded to the target API only if guardrails pass
4. **Post-call guardrails** run on the response from the target API
5. If `response_fields` is specified (e.g., `["results[*].text"]`), only those fields are evaluated. Otherwise, the entire response is checked.

:::info
If the `guardrails` block is omitted or empty in your pass-through endpoint config, the request skips the guardrail flow entirely and goes directly to the target API.
:::

```mermaid
sequenceDiagram
    participant Client
    box rgb(200, 220, 255) LiteLLM Proxy
        participant PassThrough as Pass-through Endpoint
        participant Guardrails
    end
    participant Target as Target API (Cohere, etc.)

    Client->>PassThrough: POST /special/rerank
    Note over PassThrough,Guardrails: Collect passthrough + org/team/key guardrails
    PassThrough->>Guardrails: Run pre_call (request_fields or full payload)
    Guardrails-->>PassThrough: ✓ Pass / ✗ Block
    PassThrough->>Target: Forward request
    Target-->>PassThrough: Response
    PassThrough->>Guardrails: Run post_call (response_fields or full payload)
    Guardrails-->>PassThrough: ✓ Pass / ✗ Block
    PassThrough-->>Client: Return response (or error)
```

---

## Field-Level Targeting

Target specific JSON fields instead of the entire request/response payload.

```yaml showLineNumbers title="config.yaml"
guardrails:
  - guardrail_name: "pii-detection"
    litellm_params:
      guardrail: bedrock
      mode: pre_call
      guardrailIdentifier: "pii-guard-id"
      guardrailVersion: "1"

  - guardrail_name: "content-moderation"
    litellm_params:
      guardrail: bedrock
      mode: post_call
      guardrailIdentifier: "content-guard-id"
      guardrailVersion: "1"

general_settings:
  pass_through_endpoints:
    - path: "/v1/rerank"
      target: "https://api.cohere.com/v1/rerank"
      headers:
        Authorization: "bearer os.environ/COHERE_API_KEY"
      guardrails:
        pii-detection:
          request_fields: ["query", "documents[*].text"]
        content-moderation:
          response_fields: ["results[*].text"]
```

### Field Options

| Field | Description |
|-------|-------------|
| `request_fields` | JSONPath expressions for input (pre_call) |
| `response_fields` | JSONPath expressions for output (post_call) |
| Neither specified | Guardrail runs on entire payload |

### JSONPath Examples

| Expression | Matches |
|------------|---------|
| `query` | Single field named `query` |
| `documents[*].text` | All `text` fields in `documents` array |
| `messages[*].content` | All `content` fields in `messages` array |
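These expressions are evaluated by the small extractor added in this change (`litellm/proxy/pass_through_endpoints/jsonpath_extractor.py`). A quick sketch of what each expression yields for a rerank-style payload:

```python
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor

payload = {
    "query": "What is coffee?",
    "documents": [
        {"text": "Paris is the capital of France."},
        {"text": "Coffee is a brewed drink."},
    ],
}

# Single field
JsonPathExtractor.evaluate(payload, "query")
# -> "What is coffee?"

# Array wildcard: every `text` field in the `documents` array
JsonPathExtractor.evaluate(payload, "documents[*].text")
# -> ["Paris is the capital of France.", "Coffee is a brewed drink."]

# What the guardrail actually receives: all matches joined with newlines
JsonPathExtractor.extract_fields(payload, ["query", "documents[*].text"])
# -> "What is coffee?\nParis is the capital of France.\nCoffee is a brewed drink."
```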
---

## Configuration Examples

### Single guardrail on entire payload

```yaml showLineNumbers title="config.yaml"
guardrails:
  - guardrail_name: "pii-detection"
    litellm_params:
      guardrail: bedrock
      mode: pre_call
      guardrailIdentifier: "your-id"
      guardrailVersion: "1"

general_settings:
  pass_through_endpoints:
    - path: "/v1/rerank"
      target: "https://api.cohere.com/v1/rerank"
      guardrails:
        pii-detection:
```

### Multiple guardrails with mixed settings

```yaml showLineNumbers title="config.yaml"
guardrails:
  - guardrail_name: "pii-detection"
    litellm_params:
      guardrail: bedrock
      mode: pre_call
      guardrailIdentifier: "pii-id"
      guardrailVersion: "1"

  - guardrail_name: "content-moderation"
    litellm_params:
      guardrail: bedrock
      mode: post_call
      guardrailIdentifier: "content-id"
      guardrailVersion: "1"

  - guardrail_name: "prompt-injection"
    litellm_params:
      guardrail: lakera
      mode: pre_call
      api_key: os.environ/LAKERA_API_KEY

general_settings:
  pass_through_endpoints:
    - path: "/v1/rerank"
      target: "https://api.cohere.com/v1/rerank"
      guardrails:
        pii-detection:
          request_fields: ["input", "query"]
        content-moderation:
        prompt-injection:
          request_fields: ["messages[*].content"]
```
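In the mixed example above, `content-moderation` has no field settings, so it falls back to the entire payload, while `prompt-injection` only sees the extracted `messages[*].content` values. A rough sketch of the fallback path, using the `PassthroughGuardrailHandler` helper added in this change:

```python
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
    PassthroughGuardrailHandler,
)

endpoint_guardrails = {
    "content-moderation": None,  # no field targeting configured
    "prompt-injection": {"request_fields": ["messages[*].content"]},
}

response_data = {"results": [{"text": "Coffee is a brewed drink."}]}

# No settings for this guardrail -> settings resolve to None
settings = PassthroughGuardrailHandler.get_settings(
    endpoint_guardrails, "content-moderation"
)
print(settings)  # None

# With no response_fields, the whole response is JSON-dumped and checked
print(
    PassthroughGuardrailHandler.prepare_output(
        response_data=response_data, guardrail_settings=settings
    )
)
```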
@@ -419,7 +419,8 @@ const sidebars = {
        ]
      },
      "pass_through/vllm",
-      "proxy/pass_through"
+      "proxy/pass_through",
+      "proxy/pass_through_guardrails"
    ]
  },
  "rag_ingest",
@@ -374,6 +374,9 @@ class Logging(LiteLLMLoggingBaseClass):
|
||||
# Init Caching related details
|
||||
self.caching_details: Optional[CachingDetails] = None
|
||||
|
||||
# Passthrough endpoint guardrails config for field targeting
|
||||
self.passthrough_guardrails_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
self.model_call_details: Dict[str, Any] = {
|
||||
"litellm_trace_id": litellm_trace_id,
|
||||
"litellm_call_id": litellm_call_id,
|
||||
|
||||
@@ -41,6 +41,7 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input messages by applying guardrails to text content.
|
||||
@@ -145,6 +146,7 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
self,
|
||||
response: "AnthropicMessagesResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to text content.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class BaseTranslation(ABC):
|
||||
@@ -11,6 +12,7 @@ class BaseTranslation(ABC):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
@@ -19,5 +21,6 @@ class BaseTranslation(ABC):
|
||||
self,
|
||||
response: Any,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
) -> Any:
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides guardrail translation support for the rerank endpoint.
|
||||
The handler processes only the 'query' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
@@ -34,6 +34,7 @@ class CohereRerankHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input query by applying guardrails.
|
||||
@@ -68,6 +69,7 @@ class CohereRerankHandler(BaseTranslation):
|
||||
self,
|
||||
response: "RerankResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response - not applicable for rerank.
|
||||
|
||||
@@ -42,6 +42,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input messages by applying guardrails to text content.
|
||||
@@ -148,6 +149,7 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
self,
|
||||
response: "ModelResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to text content.
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text completion
|
||||
The handler processes the 'prompt' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
@@ -32,6 +32,7 @@ class OpenAITextCompletionHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input prompt by applying guardrails to text content.
|
||||
@@ -100,6 +101,7 @@ class OpenAITextCompletionHandler(BaseTranslation):
|
||||
self,
|
||||
response: "TextCompletionResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to completion text.
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's image generation
|
||||
The handler processes the 'prompt' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
@@ -31,6 +31,7 @@ class OpenAIImageGenerationHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input prompt by applying guardrails to text content.
|
||||
@@ -72,6 +73,7 @@ class OpenAIImageGenerationHandler(BaseTranslation):
|
||||
self,
|
||||
response: "ImageResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response - typically not needed for image generation.
|
||||
|
||||
@@ -56,6 +56,7 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input by applying guardrails to text content.
|
||||
@@ -177,6 +178,7 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
self,
|
||||
response: "ResponsesAPIResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to text content.
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's text-to-speech e
|
||||
The handler processes the 'input' text parameter (output is audio, so no text to guardrail).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
@@ -30,6 +30,7 @@ class OpenAITextToSpeechHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input text by applying guardrails.
|
||||
@@ -72,6 +73,7 @@ class OpenAITextToSpeechHandler(BaseTranslation):
|
||||
self,
|
||||
response: "HttpxBinaryResponseContent",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output - not applicable for text-to-speech.
|
||||
|
||||
@@ -5,7 +5,7 @@ This module provides guardrail translation support for OpenAI's audio transcript
|
||||
The handler processes the output transcribed text (input is audio, so no text to guardrail).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
@@ -30,6 +30,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation):
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input - not applicable for audio transcription.
|
||||
@@ -54,6 +55,7 @@ class OpenAIAudioTranscriptionHandler(BaseTranslation):
|
||||
self,
|
||||
response: "TranscriptionResponse",
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output transcription by applying guardrails to transcribed text.
|
||||
|
||||
12
litellm/llms/pass_through/__init__.py
Normal file
@@ -0,0 +1,12 @@
"""
Pass-Through Endpoint Guardrail Translation

This module exists here (under litellm/llms/) so it can be auto-discovered by
load_guardrail_translation_mappings() which scans for guardrail_translation
directories under litellm/llms/.

The main passthrough endpoint implementation is in:
    litellm/proxy/pass_through_endpoints/

See guardrail_translation/README.md for more details.
"""
41
litellm/llms/pass_through/guardrail_translation/README.md
Normal file
@@ -0,0 +1,41 @@
# Pass-Through Endpoint Guardrail Translation

## Why This Exists Here

This module is located under `litellm/llms/` (instead of with the main passthrough code) because:

1. **Auto-discovery**: The `load_guardrail_translation_mappings()` function in `litellm/llms/__init__.py` scans for `guardrail_translation/` directories under `litellm/llms/`
2. **Consistency**: All other guardrail translation handlers follow this pattern (e.g., `openai/chat/guardrail_translation/`, `anthropic/chat/guardrail_translation/`)

## Main Passthrough Implementation

The main passthrough endpoint implementation is in:

```
litellm/proxy/pass_through_endpoints/
├── pass_through_endpoints.py     # Core passthrough routing logic
├── passthrough_guardrails.py     # Guardrail collection and field targeting
├── jsonpath_extractor.py         # JSONPath field extraction utility
└── ...
```

## What This Handler Does

The `PassThroughEndpointHandler` enables guardrails to run on passthrough endpoint requests by:

1. **Field Targeting**: Extracts specific fields from the request/response using JSONPath expressions configured in `request_fields` / `response_fields`
2. **Full Payload Fallback**: If no field targeting is configured, processes the entire payload
3. **Config Access**: Uses `get_passthrough_guardrails_config()` / `set_passthrough_guardrails_config()` helpers to access the passthrough guardrails configuration stored in request metadata

## Example Config

```yaml
passthrough_endpoints:
  - path: "/v1/rerank"
    target: "https://api.cohere.com/v1/rerank"
    guardrails:
      bedrock-pre-guard:
        request_fields: ["query", "documents[*].text"]
        response_fields: ["results[*].text"]
```
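Given that config, a rough sketch of how per-guardrail settings resolve to the text a guardrail actually checks, using the `PassthroughGuardrailHandler` helpers introduced alongside this handler:

```python
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
    PassthroughGuardrailHandler,
)

# The per-endpoint config above, after YAML parsing
guardrails_config = {
    "bedrock-pre-guard": {
        "request_fields": ["query", "documents[*].text"],
        "response_fields": ["results[*].text"],
    }
}

# Resolve the settings for one guardrail (returns PassThroughGuardrailSettings)
settings = PassthroughGuardrailHandler.get_settings(
    guardrails_config, "bedrock-pre-guard"
)

request_data = {
    "query": "What is coffee?",
    "documents": [{"text": "Coffee is a brewed drink."}],
}

# Only the targeted fields are joined into the text sent to the guardrail
text_for_guardrail = PassthroughGuardrailHandler.prepare_input(
    request_data=request_data, guardrail_settings=settings
)
print(text_for_guardrail)
# What is coffee?
# Coffee is a brewed drink.
```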
15
litellm/llms/pass_through/guardrail_translation/__init__.py
Normal file
@@ -0,0 +1,15 @@
"""Pass-Through Endpoint guardrail translation handler."""

from litellm.llms.pass_through.guardrail_translation.handler import (
    PassThroughEndpointHandler,
)
from litellm.types.utils import CallTypes

guardrail_translation_mappings = {
    CallTypes.pass_through: PassThroughEndpointHandler,
}

__all__ = [
    "guardrail_translation_mappings",
    "PassThroughEndpointHandler",
]
165
litellm/llms/pass_through/guardrail_translation/handler.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Pass-Through Endpoint Message Handler for Unified Guardrails
|
||||
|
||||
This module provides a handler for passthrough endpoint requests.
|
||||
It uses the field targeting configuration from litellm_logging_obj
|
||||
to extract specific fields for guardrail processing.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.proxy._types import PassThroughGuardrailSettings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class PassThroughEndpointHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing passthrough endpoint requests with guardrails.
|
||||
|
||||
Uses passthrough_guardrails_config from litellm_logging_obj
|
||||
to determine which fields to extract for guardrail processing.
|
||||
"""
|
||||
|
||||
def _get_guardrail_settings(
|
||||
self,
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"],
|
||||
guardrail_name: Optional[str],
|
||||
) -> Optional[PassThroughGuardrailSettings]:
|
||||
"""
|
||||
Get the guardrail settings for a specific guardrail from logging_obj.
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
|
||||
PassthroughGuardrailHandler,
|
||||
)
|
||||
|
||||
if litellm_logging_obj is None:
|
||||
return None
|
||||
|
||||
passthrough_config = getattr(
|
||||
litellm_logging_obj, "passthrough_guardrails_config", None
|
||||
)
|
||||
if not passthrough_config or not guardrail_name:
|
||||
return None
|
||||
|
||||
return PassthroughGuardrailHandler.get_settings(
|
||||
passthrough_config, guardrail_name
|
||||
)
|
||||
|
||||
def _extract_text_for_guardrail(
|
||||
self,
|
||||
data: dict,
|
||||
field_expressions: Optional[List[str]],
|
||||
) -> str:
|
||||
"""
|
||||
Extract text from data for guardrail processing.
|
||||
|
||||
If field_expressions provided, extracts only those fields.
|
||||
Otherwise, returns the full payload as JSON.
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import (
|
||||
JsonPathExtractor,
|
||||
)
|
||||
|
||||
if field_expressions:
|
||||
text = JsonPathExtractor.extract_fields(
|
||||
data=data,
|
||||
jsonpath_expressions=field_expressions,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: Extracted targeted fields: %s",
|
||||
text[:200] if text else None,
|
||||
)
|
||||
return text
|
||||
|
||||
# Use entire payload, excluding internal fields
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
payload_to_check = {
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if not k.startswith("_") and k not in ("metadata", "litellm_logging_obj")
|
||||
}
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: Using full payload for guardrail"
|
||||
)
|
||||
return safe_dumps(payload_to_check)
|
||||
|
||||
async def process_input_messages(
|
||||
self,
|
||||
data: dict,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process input by applying guardrails to targeted fields or full payload.
|
||||
"""
|
||||
guardrail_name = guardrail_to_apply.guardrail_name
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: Processing input for guardrail=%s",
|
||||
guardrail_name,
|
||||
)
|
||||
|
||||
# Get field targeting settings for this guardrail
|
||||
settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name)
|
||||
field_expressions = settings.request_fields if settings else None
|
||||
|
||||
# Extract text to check
|
||||
text_to_check = self._extract_text_for_guardrail(data, field_expressions)
|
||||
|
||||
if not text_to_check:
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: No text to check, skipping guardrail"
|
||||
)
|
||||
return data
|
||||
|
||||
# Apply guardrail
|
||||
await guardrail_to_apply.apply_guardrail(
|
||||
text=text_to_check,
|
||||
request_data=data,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
async def process_output_response(
|
||||
self,
|
||||
response: Any,
|
||||
guardrail_to_apply: "CustomGuardrail",
|
||||
litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Process output response by applying guardrails to targeted fields.
|
||||
"""
|
||||
if not isinstance(response, dict):
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: Response is not a dict, skipping"
|
||||
)
|
||||
return response
|
||||
|
||||
guardrail_name = guardrail_to_apply.guardrail_name
|
||||
verbose_proxy_logger.debug(
|
||||
"PassThroughEndpointHandler: Processing output for guardrail=%s",
|
||||
guardrail_name,
|
||||
)
|
||||
|
||||
# Get field targeting settings for this guardrail
|
||||
settings = self._get_guardrail_settings(litellm_logging_obj, guardrail_name)
|
||||
field_expressions = settings.response_fields if settings else None
|
||||
|
||||
# Extract text to check
|
||||
text_to_check = self._extract_text_for_guardrail(response, field_expressions)
|
||||
|
||||
if not text_to_check:
|
||||
return response
|
||||
|
||||
# Apply guardrail
|
||||
await guardrail_to_apply.apply_guardrail(
|
||||
text=text_to_check,
|
||||
request_data=response,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -1693,27 +1693,26 @@ class DynamoDBArgs(LiteLLMPydanticObjectBase):
|
||||
assume_role_aws_session_name: Optional[str] = None
|
||||
|
||||
|
||||
class PassThroughGuardrailConfig(LiteLLMPydanticObjectBase):
|
||||
class PassThroughGuardrailSettings(LiteLLMPydanticObjectBase):
|
||||
"""
|
||||
Configuration for guardrails on passthrough endpoints.
|
||||
Settings for a specific guardrail on a passthrough endpoint.
|
||||
|
||||
Passthrough endpoints are opt-in only for guardrails. Guardrails configured at
|
||||
org/team/key levels will NOT execute unless explicitly enabled here.
|
||||
Allows field-level targeting for guardrail execution.
|
||||
"""
|
||||
enabled: bool = Field(
|
||||
default=False,
|
||||
description="Whether to execute guardrails for this passthrough endpoint. When True, all org/team/key level guardrails will execute along with any passthrough-specific guardrails. When False (default), NO guardrails execute.",
|
||||
)
|
||||
specific: Optional[List[str]] = Field(
|
||||
request_fields: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Optional list of guardrail names that are specific to this passthrough endpoint. These will execute in addition to org/team/key level guardrails when enabled=True.",
|
||||
description="JSONPath expressions for input field targeting (pre_call). Examples: 'query', 'documents[*].text', 'messages[*].content'. If not specified, guardrail runs on entire request payload.",
|
||||
)
|
||||
target_fields: Optional[List[str]] = Field(
|
||||
response_fields: Optional[List[str]] = Field(
|
||||
default=None,
|
||||
description="Optional list of JSON paths to target specific fields for guardrail execution. Examples: 'messages[*].content', 'input', 'messages[?(@.role=='user')].content'. If not specified, guardrails execute on entire payload.",
|
||||
description="JSONPath expressions for output field targeting (post_call). Examples: 'results[*].text', 'output'. If not specified, guardrail runs on entire response payload.",
|
||||
)
|
||||
|
||||
|
||||
# Type alias for the guardrails dict: guardrail_name -> settings (or None for defaults)
|
||||
PassThroughGuardrailsConfig = Dict[str, Optional[PassThroughGuardrailSettings]]
|
||||
|
||||
|
||||
class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase):
|
||||
id: Optional[str] = Field(
|
||||
default=None,
|
||||
@@ -1739,9 +1738,9 @@ class PassThroughGenericEndpoint(LiteLLMPydanticObjectBase):
|
||||
default=False,
|
||||
description="Whether authentication is required for the pass-through endpoint. If True, requests to the endpoint will require a valid LiteLLM API key.",
|
||||
)
|
||||
guardrails: Optional[PassThroughGuardrailConfig] = Field(
|
||||
guardrails: Optional[PassThroughGuardrailsConfig] = Field(
|
||||
default=None,
|
||||
description="Guardrail configuration for this passthrough endpoint. When enabled, org/team/key level guardrails will execute along with any passthrough-specific guardrails. Defaults to disabled (no guardrails execute).",
|
||||
description="Guardrails configuration for this passthrough endpoint. Dict keys are guardrail names, values are optional settings for field targeting. When set, all org/team/key level guardrails will also execute. Defaults to None (no guardrails execute).",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -86,6 +86,7 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
data = await endpoint_translation.process_input_messages(
|
||||
data=data,
|
||||
guardrail_to_apply=guardrail_to_apply,
|
||||
litellm_logging_obj=data.get("litellm_logging_obj"),
|
||||
)
|
||||
|
||||
# Add guardrail to applied guardrails header
|
||||
@@ -148,6 +149,7 @@ class UnifiedLLMGuardrails(CustomLogger):
|
||||
response = await endpoint_translation.process_output_response(
|
||||
response=response, # type: ignore
|
||||
guardrail_to_apply=guardrail_to_apply,
|
||||
litellm_logging_obj=data.get("litellm_logging_obj"),
|
||||
)
|
||||
# Add guardrail to applied guardrails header
|
||||
add_guardrail_to_applied_guardrails_header(
|
||||
|
||||
95
litellm/proxy/pass_through_endpoints/jsonpath_extractor.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
JSONPath Extractor Module
|
||||
|
||||
Extracts field values from data using simple JSONPath-like expressions.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
class JsonPathExtractor:
|
||||
"""Extracts field values from data using JSONPath-like expressions."""
|
||||
|
||||
@staticmethod
|
||||
def extract_fields(
|
||||
data: dict,
|
||||
jsonpath_expressions: List[str],
|
||||
) -> str:
|
||||
"""
|
||||
Extract field values from data using JSONPath-like expressions.
|
||||
|
||||
Supports simple expressions like:
|
||||
- "query" -> data["query"]
|
||||
- "documents[*].text" -> all text fields from documents array
|
||||
- "messages[*].content" -> all content fields from messages array
|
||||
|
||||
Returns concatenated string of all extracted values.
|
||||
"""
|
||||
extracted_values: List[str] = []
|
||||
|
||||
for expr in jsonpath_expressions:
|
||||
try:
|
||||
value = JsonPathExtractor.evaluate(data, expr)
|
||||
if value:
|
||||
if isinstance(value, list):
|
||||
extracted_values.extend([str(v) for v in value if v])
|
||||
else:
|
||||
extracted_values.append(str(value))
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
"Failed to extract field %s: %s", expr, str(e)
|
||||
)
|
||||
|
||||
return "\n".join(extracted_values)
|
||||
|
||||
@staticmethod
|
||||
def evaluate(data: dict, expr: str) -> Union[str, List[str], None]:
|
||||
"""
|
||||
Evaluate a simple JSONPath-like expression.
|
||||
|
||||
Supports:
|
||||
- Simple key: "query" -> data["query"]
|
||||
- Nested key: "foo.bar" -> data["foo"]["bar"]
|
||||
- Array wildcard: "items[*].text" -> [item["text"] for item in data["items"]]
|
||||
"""
|
||||
if not expr or not data:
|
||||
return None
|
||||
|
||||
parts = expr.replace("[*]", ".[*]").split(".")
|
||||
current: Any = data
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if current is None:
|
||||
return None
|
||||
|
||||
if part == "[*]":
|
||||
# Wildcard - current should be a list
|
||||
if not isinstance(current, list):
|
||||
return None
|
||||
|
||||
# Get remaining path
|
||||
remaining_path = ".".join(parts[i + 1:])
|
||||
if not remaining_path:
|
||||
return current
|
||||
|
||||
# Recursively evaluate remaining path for each item
|
||||
results = []
|
||||
for item in current:
|
||||
if isinstance(item, dict):
|
||||
result = JsonPathExtractor.evaluate(item, remaining_path)
|
||||
if result:
|
||||
if isinstance(result, list):
|
||||
results.extend(result)
|
||||
else:
|
||||
results.append(result)
|
||||
return results if results else None
|
||||
|
||||
elif isinstance(current, dict):
|
||||
current = current.get(part)
|
||||
else:
|
||||
return None
|
||||
|
||||
return current
|
||||
|
||||
@@ -599,6 +599,7 @@ async def pass_through_request( # noqa: PLR0915
|
||||
stream: Optional[bool] = None,
|
||||
cost_per_request: Optional[float] = None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
guardrails_config: Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
Pass through endpoint handler, makes the httpx request for pass-through endpoints and ensures logging hooks are called
|
||||
@@ -614,8 +615,13 @@ async def pass_through_request( # noqa: PLR0915
|
||||
query_params: The query params
|
||||
stream: Whether to stream the response
|
||||
cost_per_request: Optional field - cost per request to the target endpoint
|
||||
custom_llm_provider: Optional field - custom LLM provider for the endpoint
|
||||
guardrails_config: Optional field - guardrails configuration for passthrough endpoint
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
|
||||
PassthroughGuardrailHandler,
|
||||
)
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
|
||||
#########################################################
|
||||
@@ -664,6 +670,45 @@ async def pass_through_request( # noqa: PLR0915
|
||||
)
|
||||
)
|
||||
|
||||
### COLLECT GUARDRAILS FOR PASSTHROUGH ENDPOINT ###
|
||||
# Passthrough endpoints are opt-in only for guardrails
|
||||
# When enabled, collect guardrails from org/team/key levels + passthrough-specific
|
||||
guardrails_to_run = PassthroughGuardrailHandler.collect_guardrails(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
passthrough_guardrails_config=guardrails_config,
|
||||
)
|
||||
|
||||
# Add guardrails to metadata if any should run
|
||||
if guardrails_to_run and len(guardrails_to_run) > 0:
|
||||
if _parsed_body is None:
|
||||
_parsed_body = {}
|
||||
if "metadata" not in _parsed_body:
|
||||
_parsed_body["metadata"] = {}
|
||||
_parsed_body["metadata"]["guardrails"] = guardrails_to_run
|
||||
verbose_proxy_logger.debug(
|
||||
f"Added guardrails to passthrough request metadata: {guardrails_to_run}"
|
||||
)
|
||||
|
||||
## LOGGING OBJECT ## - initialize before pre_call_hook so guardrails can access it
|
||||
start_time = datetime.now()
|
||||
logging_obj = Logging(
|
||||
model="unknown",
|
||||
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
|
||||
stream=False,
|
||||
call_type="pass_through_endpoint",
|
||||
start_time=start_time,
|
||||
litellm_call_id=litellm_call_id,
|
||||
function_id="1245",
|
||||
)
|
||||
|
||||
# Store passthrough guardrails config on logging_obj for field targeting
|
||||
logging_obj.passthrough_guardrails_config = guardrails_config
|
||||
|
||||
# Store logging_obj in data so guardrails can access it
|
||||
if _parsed_body is None:
|
||||
_parsed_body = {}
|
||||
_parsed_body["litellm_logging_obj"] = logging_obj
|
||||
|
||||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
_parsed_body = await proxy_logging_obj.pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
@@ -675,18 +720,6 @@ async def pass_through_request( # noqa: PLR0915
|
||||
params={"timeout": 600},
|
||||
)
|
||||
async_client = async_client_obj.client
|
||||
|
||||
# create logging object
|
||||
start_time = datetime.now()
|
||||
logging_obj = Logging(
|
||||
model="unknown",
|
||||
messages=[{"role": "user", "content": safe_dumps(_parsed_body)}],
|
||||
stream=False,
|
||||
call_type="pass_through_endpoint",
|
||||
start_time=start_time,
|
||||
litellm_call_id=litellm_call_id,
|
||||
function_id="1245",
|
||||
)
|
||||
passthrough_logging_payload = PassthroughStandardLoggingPayload(
|
||||
url=str(url),
|
||||
request_body=_parsed_body,
|
||||
@@ -1011,6 +1044,7 @@ def create_pass_through_route(
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
is_streaming_request: Optional[bool] = False,
|
||||
query_params: Optional[dict] = None,
|
||||
guardrails: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
# check if target is an adapter.py or a url
|
||||
from litellm._uuid import uuid
|
||||
@@ -1079,6 +1113,7 @@ def create_pass_through_route(
|
||||
"forward_headers": _forward_headers,
|
||||
"merge_query_params": _merge_query_params,
|
||||
"cost_per_request": cost_per_request,
|
||||
"guardrails": None,
|
||||
}
|
||||
|
||||
if passthrough_params is not None:
|
||||
@@ -1096,6 +1131,7 @@ def create_pass_through_route(
|
||||
param_cost_per_request = target_params.get(
|
||||
"cost_per_request", cost_per_request
|
||||
)
|
||||
param_guardrails = target_params.get("guardrails", None)
|
||||
|
||||
# Construct the full target URL with subpath if needed
|
||||
full_target = (
|
||||
@@ -1135,6 +1171,7 @@ def create_pass_through_route(
|
||||
custom_body=final_custom_body,
|
||||
cost_per_request=cast(Optional[float], param_cost_per_request),
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
guardrails_config=param_guardrails,
|
||||
)
|
||||
|
||||
return endpoint_func
|
||||
@@ -1769,6 +1806,7 @@ class InitPassThroughEndpointHelpers:
|
||||
dependencies: Optional[List],
|
||||
cost_per_request: Optional[float],
|
||||
endpoint_id: str,
|
||||
guardrails: Optional[dict] = None,
|
||||
):
|
||||
"""Add exact path route for pass-through endpoint"""
|
||||
route_key = f"{endpoint_id}:exact:{path}"
|
||||
@@ -1799,6 +1837,7 @@ class InitPassThroughEndpointHelpers:
|
||||
merge_query_params,
|
||||
dependencies,
|
||||
cost_per_request=cost_per_request,
|
||||
guardrails=guardrails,
|
||||
),
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
dependencies=dependencies,
|
||||
@@ -1817,6 +1856,7 @@ class InitPassThroughEndpointHelpers:
|
||||
"merge_query_params": merge_query_params,
|
||||
"dependencies": dependencies,
|
||||
"cost_per_request": cost_per_request,
|
||||
"guardrails": guardrails,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1831,6 +1871,7 @@ class InitPassThroughEndpointHelpers:
|
||||
dependencies: Optional[List],
|
||||
cost_per_request: Optional[float],
|
||||
endpoint_id: str,
|
||||
guardrails: Optional[dict] = None,
|
||||
):
|
||||
"""Add wildcard route for sub-paths"""
|
||||
wildcard_path = f"{path}/{{subpath:path}}"
|
||||
@@ -1863,6 +1904,7 @@ class InitPassThroughEndpointHelpers:
|
||||
dependencies,
|
||||
include_subpath=True,
|
||||
cost_per_request=cost_per_request,
|
||||
guardrails=guardrails,
|
||||
),
|
||||
methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
|
||||
dependencies=dependencies,
|
||||
@@ -1881,6 +1923,7 @@ class InitPassThroughEndpointHelpers:
|
||||
"merge_query_params": merge_query_params,
|
||||
"dependencies": dependencies,
|
||||
"cost_per_request": cost_per_request,
|
||||
"guardrails": guardrails,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -2057,6 +2100,9 @@ async def initialize_pass_through_endpoints(
|
||||
if _target is None:
|
||||
continue
|
||||
|
||||
# Get guardrails config if present
|
||||
_guardrails = endpoint.get("guardrails", None)
|
||||
|
||||
# Add exact path route
|
||||
verbose_proxy_logger.debug(
|
||||
"Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id
|
||||
@@ -2071,6 +2117,7 @@ async def initialize_pass_through_endpoints(
|
||||
dependencies=_dependencies,
|
||||
cost_per_request=endpoint.get("cost_per_request", None),
|
||||
endpoint_id=endpoint_id,
|
||||
guardrails=_guardrails,
|
||||
)
|
||||
|
||||
visited_endpoints.add(f"{endpoint_id}:exact:{_path}")
|
||||
@@ -2087,6 +2134,7 @@ async def initialize_pass_through_endpoints(
|
||||
dependencies=_dependencies,
|
||||
cost_per_request=endpoint.get("cost_per_request", None),
|
||||
endpoint_id=endpoint_id,
|
||||
guardrails=_guardrails,
|
||||
)
|
||||
|
||||
visited_endpoints.add(f"{endpoint_id}:subpath:{_path}")
|
||||
|
||||
333
litellm/proxy/pass_through_endpoints/passthrough_guardrails.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Passthrough Guardrails Helper Module
|
||||
|
||||
Handles guardrail execution for passthrough endpoints with:
|
||||
- Opt-in model (guardrails only run when explicitly configured)
|
||||
- Field-level targeting using JSONPath expressions
|
||||
- Automatic inheritance from org/team/key levels when enabled
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
PassThroughGuardrailsConfig,
|
||||
PassThroughGuardrailSettings,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
|
||||
|
||||
# Type for raw guardrails config input (before normalization)
|
||||
# Can be a list of names or a dict with settings
|
||||
PassThroughGuardrailsConfigInput = Union[
|
||||
List[str], # Simple list: ["guard-1", "guard-2"]
|
||||
PassThroughGuardrailsConfig, # Dict: {"guard-1": {"request_fields": [...]}}
|
||||
]
|
||||
|
||||
|
||||
class PassthroughGuardrailHandler:
|
||||
"""
|
||||
Handles guardrail execution for passthrough endpoints.
|
||||
|
||||
Passthrough endpoints use an opt-in model for guardrails:
|
||||
- Guardrails only run when explicitly configured on the endpoint
|
||||
- Supports field-level targeting using JSONPath expressions
|
||||
- Automatically inherits org/team/key level guardrails when enabled
|
||||
|
||||
Guardrails can be specified as:
|
||||
- List format (simple): ["guardrail-1", "guardrail-2"]
|
||||
- Dict format (with settings): {"guardrail-1": {"request_fields": ["query"]}}
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def normalize_config(
|
||||
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
|
||||
) -> Optional[PassThroughGuardrailsConfig]:
|
||||
"""
|
||||
Normalize guardrails config to dict format.
|
||||
|
||||
Accepts:
|
||||
- List of guardrail names: ["g1", "g2"] -> {"g1": None, "g2": None}
|
||||
- Dict with settings: {"g1": {"request_fields": [...]}}
|
||||
- None: returns None
|
||||
"""
|
||||
if guardrails_config is None:
|
||||
return None
|
||||
|
||||
# Already a dict - return as-is
|
||||
if isinstance(guardrails_config, dict):
|
||||
return guardrails_config
|
||||
|
||||
# List of guardrail names - convert to dict
|
||||
if isinstance(guardrails_config, list):
|
||||
return {name: None for name in guardrails_config}
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Passthrough guardrails config is not a dict or list, got: %s",
|
||||
type(guardrails_config),
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_enabled(
|
||||
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
|
||||
) -> bool:
|
||||
"""
|
||||
Check if guardrails are enabled for a passthrough endpoint.
|
||||
|
||||
Passthrough endpoints are opt-in only - guardrails only run when
|
||||
the guardrails config is set with at least one guardrail.
|
||||
"""
|
||||
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
|
||||
if normalized is None:
|
||||
return False
|
||||
return len(normalized) > 0
|
||||
|
||||
@staticmethod
|
||||
def get_guardrail_names(
|
||||
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
|
||||
) -> List[str]:
|
||||
"""Get the list of guardrail names configured for a passthrough endpoint."""
|
||||
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
|
||||
if normalized is None:
|
||||
return []
|
||||
return list(normalized.keys())
|
||||
|
||||
@staticmethod
|
||||
def get_settings(
|
||||
guardrails_config: Optional[PassThroughGuardrailsConfigInput],
|
||||
guardrail_name: str,
|
||||
) -> Optional[PassThroughGuardrailSettings]:
|
||||
"""Get settings for a specific guardrail from the passthrough config."""
|
||||
normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
|
||||
if normalized is None:
|
||||
return None
|
||||
|
||||
settings = normalized.get(guardrail_name)
|
||||
if settings is None:
|
||||
return None
|
||||
|
||||
if isinstance(settings, dict):
|
||||
return PassThroughGuardrailSettings(**settings)
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def prepare_input(
|
||||
request_data: dict,
|
||||
guardrail_settings: Optional[PassThroughGuardrailSettings],
|
||||
) -> str:
|
||||
"""
|
||||
Prepare input text for guardrail execution based on field targeting settings.
|
||||
|
||||
If request_fields is specified, extracts only those fields.
|
||||
Otherwise, uses the entire request payload as text.
|
||||
"""
|
||||
if guardrail_settings is None or guardrail_settings.request_fields is None:
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
return safe_dumps(request_data)
|
||||
|
||||
return JsonPathExtractor.extract_fields(
|
||||
data=request_data,
|
||||
jsonpath_expressions=guardrail_settings.request_fields,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def prepare_output(
|
||||
response_data: dict,
|
||||
guardrail_settings: Optional[PassThroughGuardrailSettings],
|
||||
) -> str:
|
||||
"""
|
||||
Prepare output text for guardrail execution based on field targeting settings.
|
||||
|
||||
If response_fields is specified, extracts only those fields.
|
||||
Otherwise, uses the entire response payload as text.
|
||||
"""
|
||||
if guardrail_settings is None or guardrail_settings.response_fields is None:
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
return safe_dumps(response_data)
|
||||
|
||||
return JsonPathExtractor.extract_fields(
|
||||
data=response_data,
|
||||
jsonpath_expressions=guardrail_settings.response_fields,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def execute(
|
||||
request_data: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
guardrails_config: Optional[PassThroughGuardrailsConfig],
|
||||
event_type: str = "pre_call",
|
||||
) -> dict:
|
||||
"""
|
||||
Execute guardrails for a passthrough endpoint.
|
||||
|
||||
This is the main entry point for passthrough guardrail execution.
|
||||
|
||||
Args:
|
||||
request_data: The request payload
|
||||
user_api_key_dict: User API key authentication info
|
||||
guardrails_config: Passthrough-specific guardrails configuration
|
||||
event_type: "pre_call" for request, "post_call" for response
|
||||
|
||||
Returns:
|
||||
The potentially modified request_data
|
||||
|
||||
Raises:
|
||||
HTTPException if a guardrail blocks the request
|
||||
"""
|
||||
if not PassthroughGuardrailHandler.is_enabled(guardrails_config):
|
||||
verbose_proxy_logger.debug(
|
||||
"Passthrough guardrails not enabled, skipping guardrail execution"
|
||||
)
|
||||
return request_data
|
||||
|
||||
guardrail_names = PassthroughGuardrailHandler.get_guardrail_names(
|
||||
guardrails_config
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Executing passthrough guardrails: %s", guardrail_names
|
||||
)
|
||||
|
||||
# Add to request metadata so guardrails know which to run
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_context import (
|
||||
set_passthrough_guardrails_config,
|
||||
)
|
||||
|
||||
if "metadata" not in request_data:
|
||||
request_data["metadata"] = {}
|
||||
|
||||
# Set guardrails in metadata using dict format for compatibility
|
||||
request_data["metadata"]["guardrails"] = {
|
||||
name: True for name in guardrail_names
|
||||
}
|
||||
|
||||
# Store passthrough guardrails config in request-scoped context
|
||||
set_passthrough_guardrails_config(guardrails_config)
|
||||
|
||||
return request_data
|
||||
|
||||
@staticmethod
|
||||
def collect_guardrails(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
passthrough_guardrails_config: Optional[PassThroughGuardrailsConfigInput],
|
||||
) -> Optional[Dict[str, bool]]:
|
||||
"""
|
||||
Collect guardrails for a passthrough endpoint.
|
||||
|
||||
Passthrough endpoints are opt-in only for guardrails. Guardrails only run when
|
||||
the guardrails config is set with at least one guardrail.
|
||||
|
||||
Accepts both list and dict formats:
|
||||
- List: ["guardrail-1", "guardrail-2"]
|
||||
- Dict: {"guardrail-1": {"request_fields": [...]}}
|
||||
|
||||
When enabled, this function collects:
|
||||
- Passthrough-specific guardrails from the config
|
||||
- Org/team/key level guardrails (automatic inheritance when passthrough is enabled)
|
||||
|
||||
Args:
|
||||
user_api_key_dict: User API key authentication info
|
||||
passthrough_guardrails_config: List or Dict of guardrail names/settings
|
||||
|
||||
Returns:
|
||||
Dict of guardrail names to run (format: {guardrail_name: True}), or None
|
||||
"""
|
||||
from litellm.proxy.litellm_pre_call_utils import (
|
||||
_add_guardrails_from_key_or_team_metadata,
|
||||
)
|
||||
|
||||
# Normalize config to dict format (handles both list and dict)
|
||||
normalized_config = PassthroughGuardrailHandler.normalize_config(
|
||||
passthrough_guardrails_config
|
||||
)
|
||||
|
||||
if normalized_config is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"Passthrough guardrails not configured, skipping guardrail collection"
|
||||
)
|
||||
return None
|
||||
|
||||
if len(normalized_config) == 0:
|
||||
verbose_proxy_logger.debug(
|
||||
"Passthrough guardrails config is empty, skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
# Passthrough is enabled - collect guardrails
|
||||
guardrails_to_run: Dict[str, bool] = {}
|
||||
|
||||
# Add passthrough-specific guardrails
|
||||
for guardrail_name in normalized_config.keys():
|
||||
guardrails_to_run[guardrail_name] = True
|
||||
verbose_proxy_logger.debug(
|
||||
"Added passthrough-specific guardrail"
|
||||
)
|
||||
|
||||
# Add org/team/key level guardrails using shared helper
|
||||
temp_data: Dict[str, Any] = {"metadata": {}}
|
||||
_add_guardrails_from_key_or_team_metadata(
|
||||
key_metadata=user_api_key_dict.metadata,
|
||||
team_metadata=user_api_key_dict.team_metadata,
|
||||
data=temp_data,
|
||||
metadata_variable_name="metadata",
|
||||
)
|
||||
|
||||
# Merge inherited guardrails into guardrails_to_run
|
||||
inherited_guardrails = temp_data["metadata"].get("guardrails", [])
|
||||
for guardrail_name in inherited_guardrails:
|
||||
if guardrail_name not in guardrails_to_run:
|
||||
guardrails_to_run[guardrail_name] = True
|
||||
verbose_proxy_logger.debug(
|
||||
"Added inherited guardrail (key/team level)"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Collected total guardrails for passthrough endpoint: %d",
|
||||
len(guardrails_to_run),
|
||||
)
|
||||
|
||||
return guardrails_to_run if guardrails_to_run else None
|
||||
|
||||
@staticmethod
|
||||
def get_field_targeted_text(
|
||||
data: dict,
|
||||
guardrail_name: str,
|
||||
is_request: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the text to check for a guardrail, respecting field targeting settings.
|
||||
|
||||
Called by guardrail hooks to get the appropriate text based on
|
||||
passthrough field targeting configuration.
|
||||
|
||||
Args:
|
||||
data: The request/response data dict
|
||||
guardrail_name: Name of the guardrail being executed
|
||||
is_request: True for request (pre_call), False for response (post_call)
|
||||
|
||||
Returns:
|
||||
The text to check, or None to use default behavior
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.passthrough_context import (
|
||||
get_passthrough_guardrails_config,
|
||||
)
|
||||
|
||||
passthrough_config = get_passthrough_guardrails_config()
|
||||
if passthrough_config is None:
|
||||
return None
|
||||
|
||||
settings = PassthroughGuardrailHandler.get_settings(
|
||||
passthrough_config, guardrail_name
|
||||
)
|
||||
if settings is None:
|
||||
return None
|
||||
|
||||
if is_request:
|
||||
if settings.request_fields:
|
||||
return JsonPathExtractor.extract_fields(data, settings.request_fields)
|
||||
else:
|
||||
if settings.response_fields:
|
||||
return JsonPathExtractor.extract_fields(data, settings.response_fields)
|
||||
|
||||
return None
|
||||
@@ -3,7 +3,14 @@ model_list:
|
||||
litellm_params:
|
||||
model: bedrock/openai/arn:aws:bedrock:us-east-1:046319184608:imported-model/0m2lasirsp6z
|
||||
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: "bedrock-pre-guard"
|
||||
litellm_params:
|
||||
guardrail: bedrock
|
||||
mode: "pre_call"
|
||||
guardrailIdentifier: ff6ujrregl1q
|
||||
guardrailVersion: "DRAFT"
|
||||
|
||||
|
||||
# like MCPs/vector stores
|
||||
search_tools:
|
||||
@@ -40,6 +47,14 @@ litellm_settings:
|
||||
|
||||
general_settings:
|
||||
store_prompts_in_spend_logs: True
|
||||
pass_through_endpoints:
|
||||
- path: "/special/rerank"
|
||||
target: "https://api.cohere.com/v1/rerank"
|
||||
headers:
|
||||
Authorization: "Bearer os.environ/COHERE_API_KEY"
|
||||
guardrails:
|
||||
bedrock-pre-guard:
|
||||
request_fields: ["documents[*].text"]
|
||||
|
||||
|
||||
vector_store_registry:
|
||||
|
||||
@@ -0,0 +1,119 @@
"""
Test passthrough guardrails field-level targeting.

Tests that request_fields and response_fields correctly extract
and send only specified fields to the guardrail.
"""

import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock

import pytest

sys.path.insert(0, os.path.abspath("../.."))

from litellm.proxy._types import PassThroughGuardrailSettings
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
    PassthroughGuardrailHandler,
)


def test_no_fields_set_sends_full_body():
    """
    Test that when no request_fields are set, the entire request body
    is JSON dumped and sent to the guardrail.
    """
    request_data = {
        "model": "rerank-english-v3.0",
        "query": "What is coffee?",
        "documents": [
            {"text": "Paris is the capital of France."},
            {"text": "Coffee is a brewed drink."}
        ]
    }

    # No guardrail settings means full body
    result = PassthroughGuardrailHandler.prepare_input(
        request_data=request_data,
        guardrail_settings=None
    )

    # Result should be JSON string of full request
    assert isinstance(result, str)
    result_dict = json.loads(result)

    # Should contain all fields
    assert "model" in result_dict
    assert "query" in result_dict
    assert "documents" in result_dict
    assert result_dict["query"] == "What is coffee?"
    assert len(result_dict["documents"]) == 2


def test_request_fields_query_only():
    """
    Test that when request_fields is set to ["query"], only the query field
    is extracted and sent to the guardrail.
    """
    request_data = {
        "model": "rerank-english-v3.0",
        "query": "What is coffee?",
        "documents": [
            {"text": "Paris is the capital of France."},
            {"text": "Coffee is a brewed drink."}
        ]
    }

    # Set request_fields to only extract query
    guardrail_settings = PassThroughGuardrailSettings(
        request_fields=["query"]
    )

    result = PassthroughGuardrailHandler.prepare_input(
        request_data=request_data,
        guardrail_settings=guardrail_settings
    )

    # Result should only contain query
    assert isinstance(result, str)
    assert "What is coffee?" in result

    # Should NOT contain documents
    assert "Paris is the capital" not in result
    assert "Coffee is a brewed drink" not in result


def test_request_fields_documents_wildcard():
    """
    Test that when request_fields is set to ["documents[*]"], only the documents
    array is extracted and sent to the guardrail.
    """
    request_data = {
        "model": "rerank-english-v3.0",
        "query": "What is coffee?",
        "documents": [
            {"text": "Paris is the capital of France."},
            {"text": "Coffee is a brewed drink."}
        ]
    }

    # Set request_fields to extract documents array
    guardrail_settings = PassThroughGuardrailSettings(
        request_fields=["documents[*]"]
    )

    result = PassthroughGuardrailHandler.prepare_input(
        request_data=request_data,
        guardrail_settings=guardrail_settings
    )

    # Result should contain documents
    assert isinstance(result, str)
    assert "Paris is the capital" in result
    assert "Coffee is a brewed drink" in result

    # Should NOT contain query
    assert "What is coffee?" not in result
@@ -0,0 +1,263 @@
"""
Unit tests for passthrough guardrails functionality.

Tests the opt-in guardrail execution model for passthrough endpoints:
- Guardrails only run when explicitly configured
- Field-level targeting with JSONPath expressions
- Automatic inheritance from org/team/key levels when enabled
"""

import pytest

from litellm.proxy._types import PassThroughGuardrailSettings
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
    PassthroughGuardrailHandler,
)


class TestPassthroughGuardrailHandlerIsEnabled:
    """Tests for PassthroughGuardrailHandler.is_enabled method."""

    def test_returns_false_when_config_is_none(self):
        """Guardrails should be disabled when config is None."""
        result = PassthroughGuardrailHandler.is_enabled(None)
        assert result is False

    def test_returns_false_when_config_is_empty_dict(self):
        """Guardrails should be disabled when config is empty dict."""
        result = PassthroughGuardrailHandler.is_enabled({})
        assert result is False

    def test_returns_true_when_config_is_list(self):
        """Guardrails should be enabled when config is a list of names."""
        result = PassthroughGuardrailHandler.is_enabled(["pii-detection"])
        assert result is True

    def test_returns_false_when_config_is_invalid_type(self):
        """Guardrails should be disabled when config is not a dict or list."""
        result = PassthroughGuardrailHandler.is_enabled("pii-detection")  # type: ignore
        assert result is False

    def test_returns_true_when_config_has_guardrails(self):
        """Guardrails should be enabled when config has at least one guardrail."""
        config = {"pii-detection": None}
        result = PassthroughGuardrailHandler.is_enabled(config)
        assert result is True

    def test_returns_true_with_multiple_guardrails(self):
        """Guardrails should be enabled with multiple guardrails configured."""
        config = {
            "pii-detection": None,
            "content-moderation": {"request_fields": ["input"]},
        }
        result = PassthroughGuardrailHandler.is_enabled(config)
        assert result is True


class TestPassthroughGuardrailHandlerGetGuardrailNames:
    """Tests for PassthroughGuardrailHandler.get_guardrail_names method."""

    def test_returns_empty_list_when_disabled(self):
        """Should return empty list when guardrails are disabled."""
        result = PassthroughGuardrailHandler.get_guardrail_names(None)
        assert result == []

    def test_returns_guardrail_names(self):
        """Should return list of guardrail names from config."""
        config = {
            "pii-detection": None,
            "content-moderation": {"request_fields": ["input"]},
        }
        result = PassthroughGuardrailHandler.get_guardrail_names(config)
        assert set(result) == {"pii-detection", "content-moderation"}


class TestPassthroughGuardrailHandlerNormalizeConfig:
    """Tests for PassthroughGuardrailHandler.normalize_config method."""

    def test_normalizes_list_to_dict(self):
        """List of guardrail names should be converted to dict with None values."""
        config = ["pii-detection", "content-moderation"]
        result = PassthroughGuardrailHandler.normalize_config(config)
        assert result == {"pii-detection": None, "content-moderation": None}

    def test_returns_dict_unchanged(self):
        """Dict config should be returned as-is."""
        config = {"pii-detection": {"request_fields": ["query"]}}
        result = PassthroughGuardrailHandler.normalize_config(config)
        assert result == config


class TestPassthroughGuardrailHandlerGetSettings:
    """Tests for PassthroughGuardrailHandler.get_settings method."""

    def test_returns_none_when_config_is_none(self):
        """Should return None when config is None."""
        result = PassthroughGuardrailHandler.get_settings(None, "pii-detection")
        assert result is None

    def test_returns_none_when_guardrail_not_in_config(self):
        """Should return None when guardrail is not in config."""
        config = {"pii-detection": None}
        result = PassthroughGuardrailHandler.get_settings(config, "content-moderation")
        assert result is None

    def test_returns_none_when_settings_is_none(self):
        """Should return None when guardrail has no settings."""
        config = {"pii-detection": None}
        result = PassthroughGuardrailHandler.get_settings(config, "pii-detection")
        assert result is None

    def test_returns_settings_object(self):
        """Should return PassThroughGuardrailSettings when settings are provided."""
        config = {
            "pii-detection": {
                "request_fields": ["input", "query"],
                "response_fields": ["output"],
            }
        }
        result = PassthroughGuardrailHandler.get_settings(config, "pii-detection")
        assert result is not None
        assert result.request_fields == ["input", "query"]
        assert result.response_fields == ["output"]


class TestJsonPathExtractorEvaluate:
    """Tests for JsonPathExtractor.evaluate method."""

    def test_simple_key(self):
        """Should extract simple key from dict."""
        data = {"query": "test query", "other": "value"}
        result = JsonPathExtractor.evaluate(data, "query")
        assert result == "test query"

    def test_nested_key(self):
        """Should extract nested key from dict."""
        data = {"foo": {"bar": "nested value"}}
        result = JsonPathExtractor.evaluate(data, "foo.bar")
        assert result == "nested value"

    def test_array_wildcard(self):
        """Should extract values from array using wildcard."""
        data = {"items": [{"text": "item1"}, {"text": "item2"}, {"text": "item3"}]}
        result = JsonPathExtractor.evaluate(data, "items[*].text")
        assert result == ["item1", "item2", "item3"]

    def test_messages_content(self):
        """Should extract content from messages array."""
        data = {
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"},
            ]
        }
        result = JsonPathExtractor.evaluate(data, "messages[*].content")
        assert result == ["You are helpful", "Hello"]

    def test_missing_key_returns_none(self):
        """Should return None for missing key."""
        data = {"query": "test"}
        result = JsonPathExtractor.evaluate(data, "missing")
        assert result is None

    def test_empty_data_returns_none(self):
        """Should return None for empty data."""
        result = JsonPathExtractor.evaluate({}, "query")
        assert result is None

    def test_empty_expression_returns_none(self):
        """Should return None for empty expression."""
        data = {"query": "test"}
        result = JsonPathExtractor.evaluate(data, "")
        assert result is None


class TestJsonPathExtractorExtractFields:
    """Tests for JsonPathExtractor.extract_fields method."""

    def test_extracts_multiple_fields(self):
        """Should extract and concatenate multiple fields."""
        data = {"query": "search query", "input": "additional input"}
        result = JsonPathExtractor.extract_fields(data, ["query", "input"])
        assert "search query" in result
        assert "additional input" in result

    def test_extracts_array_fields(self):
        """Should extract and concatenate array fields."""
        data = {"documents": [{"text": "doc1"}, {"text": "doc2"}]}
        result = JsonPathExtractor.extract_fields(data, ["documents[*].text"])
        assert "doc1" in result
        assert "doc2" in result

    def test_handles_missing_fields(self):
        """Should handle missing fields gracefully."""
        data = {"query": "test"}
        result = JsonPathExtractor.extract_fields(data, ["query", "missing"])
        assert result == "test"

    def test_empty_fields_returns_empty_string(self):
        """Should return empty string for empty fields list."""
        data = {"query": "test"}
        result = JsonPathExtractor.extract_fields(data, [])
        assert result == ""


class TestPassthroughGuardrailHandlerPrepareInput:
    """Tests for PassthroughGuardrailHandler.prepare_input method."""

    def test_returns_full_payload_when_no_settings(self):
        """Should return full JSON payload when no settings provided."""
        data = {"query": "test", "input": "value"}
        result = PassthroughGuardrailHandler.prepare_input(data, None)
        assert "query" in result
        assert "input" in result

    def test_returns_full_payload_when_no_request_fields(self):
        """Should return full JSON payload when request_fields not set."""
        data = {"query": "test", "input": "value"}
        settings = PassThroughGuardrailSettings(response_fields=["output"])
        result = PassthroughGuardrailHandler.prepare_input(data, settings)
        assert "query" in result
        assert "input" in result

    def test_returns_targeted_fields(self):
        """Should return only targeted fields when request_fields set."""
        data = {"query": "targeted", "input": "also targeted", "other": "ignored"}
        settings = PassThroughGuardrailSettings(request_fields=["query", "input"])
        result = PassthroughGuardrailHandler.prepare_input(data, settings)
        assert "targeted" in result
        assert "also targeted" in result
        assert "ignored" not in result


class TestPassthroughGuardrailHandlerPrepareOutput:
    """Tests for PassthroughGuardrailHandler.prepare_output method."""

    def test_returns_full_payload_when_no_settings(self):
        """Should return full JSON payload when no settings provided."""
        data = {"results": [{"text": "result1"}], "output": "value"}
        result = PassthroughGuardrailHandler.prepare_output(data, None)
        assert "results" in result
        assert "output" in result

    def test_returns_full_payload_when_no_response_fields(self):
        """Should return full JSON payload when response_fields not set."""
        data = {"results": [{"text": "result1"}], "output": "value"}
        settings = PassThroughGuardrailSettings(request_fields=["input"])
        result = PassthroughGuardrailHandler.prepare_output(data, settings)
        assert "results" in result
        assert "output" in result

    def test_returns_targeted_fields(self):
        """Should return only targeted fields when response_fields set."""
        data = {
            "results": [{"text": "targeted1"}, {"text": "targeted2"}],
            "other": "ignored",
        }
        settings = PassThroughGuardrailSettings(response_fields=["results[*].text"])
        result = PassthroughGuardrailHandler.prepare_output(data, settings)
        assert "targeted1" in result
        assert "targeted2" in result
        assert "ignored" not in result
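Taken together, the tests above pin down the opt-in flow for a pass-through request. A condensed sketch of that flow, assuming the same handler API; the ordering and the guardrail names and field paths here are illustrative, not prescribed by the commit:

```python
from litellm.proxy.pass_through_endpoints.passthrough_guardrails import (
    PassthroughGuardrailHandler,
)

# Illustrative per-endpoint config, in the same shape the YAML config produces
endpoint_guardrails = {"bedrock-pre-guard": {"request_fields": ["documents[*].text"]}}

request_data = {"query": "q", "documents": [{"text": "a"}, {"text": "b"}]}

# 1. Opt-in check: nothing runs unless the endpoint sets `guardrails`
if PassthroughGuardrailHandler.is_enabled(endpoint_guardrails):
    # 2. A plain list of names is normalized to {name: None}; dicts pass through
    config = PassthroughGuardrailHandler.normalize_config(endpoint_guardrails)
    # 3. Resolve which guardrails apply and their optional field targeting
    for name in PassthroughGuardrailHandler.get_guardrail_names(config):
        settings = PassthroughGuardrailHandler.get_settings(config, name)
        # 4. Build the guardrail input: targeted fields if set, else the full body
        guardrail_input = PassthroughGuardrailHandler.prepare_input(
            request_data=request_data,
            guardrail_settings=settings,
        )
        # ... run the named guardrail on guardrail_input ...
```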