Model Armor - Logging guardrail response on LLM responses (#16977)

* Litellm dev 11 22 2025 p1 (#16975)

* fix(model_armor.py): return response after applying changes

* fix: initial commit adding guardrail span logging to otel on post-call runs

sends it as a separate span for now; it still needs to be included in the same LLM request/response span

* fix(opentelemetry.py): include guardrail in received request log + set input/output fields on parent otel span instead of nesting them

allows the request/response to be seen easily in observability tools

* fix(model_armor.py): working model armor logging on post call events

* fix: fix exception message

* fix(opentelemetry.py): add backwards compatibility for litellm_request

allows users who built against the previous spec to keep using it
Krish Dholakia
2025-11-22 15:44:28 -08:00
committed by GitHub
parent e11d34eb69
commit b9f2cc1c98
10 changed files with 169 additions and 103 deletions
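
For readers skimming the diff below, the core of the change: a post-call guardrail (here, Model Armor) now appends its result to the request's standard logging payload, and the OpenTelemetry logger turns that record into a guardrail span on the request's trace. A minimal, dependency-free sketch of that bookkeeping, using plain dicts (the real code uses LiteLLM's typed `StandardLoggingGuardrailInformation` / `StandardLoggingPayload` objects and the `add_guardrail_response_to_standard_logging_object` helper added in this commit):

```python
from typing import Optional


def add_guardrail_record(
    standard_logging_object: Optional[dict], record: dict
) -> Optional[dict]:
    """Append one guardrail result to the payload's guardrail_information list."""
    if standard_logging_object is None:
        return None
    guardrail_information = standard_logging_object.get("guardrail_information") or []
    guardrail_information.append(record)
    standard_logging_object["guardrail_information"] = guardrail_information
    return standard_logging_object


# Illustrative values only -- the shape mirrors what Model Armor attaches on post_call.
payload = {"model": "gpt-3.5-turbo", "guardrail_information": None}
add_guardrail_record(
    payload,
    {
        "guardrail_name": "model-armor-shield",
        "guardrail_provider": "model_armor",
        "guardrail_mode": "post_call",
        "guardrail_status": "success",
        "guardrail_response": {"model_armor_status": "success"},
    },
)
print(payload["guardrail_information"])
```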

View File

@@ -8,6 +8,18 @@ OpenTelemetry is a CNCF standard for observability. It connects to any observabi
<Image img={require('../../img/traceloop_dash.png')} />
:::note Change in v1.81.0
From v1.81.0, the request/response will be set as attributes on the parent "Received Proxy Server Request" span by default. This allows you to see the request/response in the parent span in your observability tool.
To use the older behavior with nested "litellm_request" spans, set the following environment variable:
```shell
USE_OTEL_LITELLM_REQUEST_SPAN=true
```
:::
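
To make the note above concrete: previously the LLM input/output lived on a nested `litellm_request` child span; from v1.81.0 it is written as attributes on the parent "Received Proxy Server Request" span unless `USE_OTEL_LITELLM_REQUEST_SPAN=true` is set. A small OpenTelemetry SDK sketch of the two trace shapes (the tracer setup and the `gen_ai.*` attribute names here are illustrative, not LiteLLM's exact keys):

```python
import os

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer("example")

use_nested = os.getenv("USE_OTEL_LITELLM_REQUEST_SPAN", "").lower() == "true"

with tracer.start_as_current_span("Received Proxy Server Request") as parent:
    if use_nested:
        # Pre-v1.81.0 shape: request/response recorded on a nested child span.
        with tracer.start_as_current_span("litellm_request") as child:
            child.set_attribute("gen_ai.prompt", "hello")
            child.set_attribute("gen_ai.completion", "hi there")
    else:
        # v1.81.0+ default: request/response recorded as attributes on the parent span.
        parent.set_attribute("gen_ai.prompt", "hello")
        parent.set_attribute("gen_ai.completion", "hi there")
```

Either way the data is the same; the environment variable only controls whether your observability tool shows it on the parent span or one level deeper.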
## Getting Started
Install the OpenTelemetry SDK:

View File

@@ -99,13 +99,13 @@ class ArizeLogger(OpenTelemetry):
"""Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
pass
def create_litellm_proxy_request_started_span(
self,
start_time: datetime,
headers: dict,
):
"""Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
pass
# def create_litellm_proxy_request_started_span(
# self,
# start_time: datetime,
# headers: dict,
# ):
# """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
# pass
async def async_health_check(self):
"""
@@ -117,14 +117,10 @@ class ArizeLogger(OpenTelemetry):
try:
config = self.get_arize_config()
# Prefer ARIZE_SPACE_KEY, but fall back to ARIZE_SPACE_ID for backwards compatibility
effective_space_key = config.space_key or config.space_id
if not effective_space_key:
if not config.space_id and not config.space_key:
return {
"status": "unhealthy",
# Tests (and users) expect the error message to reference ARIZE_SPACE_KEY
"error_message": "ARIZE_SPACE_KEY environment variable not set",
"error_message": "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set",
}
if not config.api_key:

View File

@@ -477,6 +477,7 @@ class CustomGuardrail(CustomLogger):
"""
# Convert None to empty dict to satisfy type requirements
guardrail_response = {} if response is None else response
self.add_standard_logging_guardrail_information_to_request_data(
guardrail_json_response=guardrail_response,
request_data=request_data,

View File

@@ -7,11 +7,13 @@ import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.secret_managers.main import get_secret_bool
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
ChatCompletionMessageToolCall,
CostBreakdown,
Function,
LLMResponseTypes,
StandardCallbackDynamicParams,
StandardLoggingPayload,
)
@@ -487,6 +489,28 @@ class OpenTelemetry(CustomLogger):
# End Parent OTEL Sspan
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response: LLMResponseTypes,
):
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
litellm_logging_obj = data.get("litellm_logging_obj")
if litellm_logging_obj is not None and isinstance(
litellm_logging_obj, LiteLLMLogging
):
kwargs = litellm_logging_obj.model_call_details
parent_span = user_api_key_dict.parent_otel_span
ctx, _ = self._get_span_context(kwargs, default_span=parent_span)
# 3. Guardrail span
self._create_guardrail_span(kwargs=kwargs, context=ctx)
return response
#########################################################
# Team/Key Based Logging Control Flow
#########################################################
@@ -565,8 +589,15 @@ class OpenTelemetry(CustomLogger):
)
ctx, parent_span = self._get_span_context(kwargs)
if get_secret_bool("USE_OTEL_LITELLM_REQUEST_SPAN"):
primary_span_parent = None
else:
primary_span_parent = parent_span
# 1. Primary span
span = self._start_primary_span(kwargs, response_obj, start_time, end_time, ctx)
span = self._start_primary_span(
kwargs, response_obj, start_time, end_time, ctx, primary_span_parent
)
# 2. Rawrequest sub-span (if enabled)
self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span)
@@ -585,11 +616,19 @@ class OpenTelemetry(CustomLogger):
if parent_span is not None:
parent_span.end(end_time=self._to_ns(datetime.now()))
def _start_primary_span(self, kwargs, response_obj, start_time, end_time, context):
def _start_primary_span(
self,
kwargs,
response_obj,
start_time,
end_time,
context,
parent_span: Optional[Span] = None,
):
from opentelemetry.trace import Status, StatusCode
otel_tracer: Tracer = self.get_tracer_to_use_for_request(kwargs)
span = otel_tracer.start_span(
span = parent_span or otel_tracer.start_span(
name=self._get_span_name(kwargs),
start_time=self._to_ns(start_time),
context=context,
@@ -779,6 +818,7 @@ class OpenTelemetry(CustomLogger):
guardrail_information_data = standard_logging_payload.get(
"guardrail_information"
)
if not guardrail_information_data:
return
@@ -1372,7 +1412,7 @@ class OpenTelemetry(CustomLogger):
return _parent_context
def _get_span_context(self, kwargs):
def _get_span_context(self, kwargs, default_span: Optional[Span] = None):
from opentelemetry import context, trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,

View File

@@ -3545,6 +3545,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
_in_memory_loggers.append(_arize_otel_logger)
return _arize_otel_logger # type: ignore
elif logging_integration == "arize_phoenix":
from litellm.integrations.opentelemetry import (
OpenTelemetry,
OpenTelemetryConfig,
@@ -3574,9 +3575,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
)
else:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = f"openinference.project.name={phoenix_project_name}"
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={phoenix_project_name}"
)
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:

View File

@@ -1,15 +1,21 @@
model_list:
- model_name: gpt-5-mini
litellm_params:
model: gpt-5-mini
- model_name: embedding-model
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: openai/text-embedding-3-large
- model_name: gpt-4o-mini-transcribe
litellm_params:
model: openai/gpt-4o-mini-transcribe
model: openai/gpt-3.5-turbo
api_key: os.environ/OPENAI_API_KEY
guardrails:
- guardrail_name: model-armor-shield
litellm_params:
guardrail: model_armor
mode: "post_call" # Run on both input and output
template_id: "test-prompt-template" # Required: Your Model Armor template ID
project_id: "test-vector-store-db" # Your GCP project ID
location: "us" # GCP location (default: us-central1)
mask_request_content: true # Enable request content masking
mask_response_content: true # Enable response content masking
fail_on_error: true # Fail request if Model Armor errors (default: true)
default_on: true # Run by default for all requests
litellm_settings:
callbacks: ["arize"]
callbacks: ["arize_phoenix"]

View File

@@ -1,14 +1,21 @@
from typing import Any, Dict, Iterable, List, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.utils import (
StandardLoggingGuardrailInformation,
StandardLoggingPayload,
)
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
def initialize_callbacks_on_proxy( # noqa: PLR0915
value: Any,
@@ -365,6 +372,26 @@ def add_guardrail_to_applied_guardrails_header(
_metadata["applied_guardrails"] = [guardrail_name]
def add_guardrail_response_to_standard_logging_object(
litellm_logging_obj: Optional["LiteLLMLogging"],
guardrail_response: StandardLoggingGuardrailInformation,
):
if litellm_logging_obj is None:
return
standard_logging_object: Optional[StandardLoggingPayload] = (
litellm_logging_obj.model_call_details.get("standard_logging_object")
)
if standard_logging_object is None:
return
guardrail_information = standard_logging_object.get("guardrail_information", [])
if guardrail_information is None:
guardrail_information = []
guardrail_information.append(guardrail_response)
standard_logging_object["guardrail_information"] = guardrail_information
return standard_logging_object
def get_metadata_variable_name_from_kwargs(
kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:

View File

@@ -29,10 +29,12 @@ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import (
CallTypesLiteral,
Choices,
GuardrailStatus,
ModelResponse,
ModelResponseStream,
StandardLoggingGuardrailInformation,
)
GUARDRAIL_NAME = "model_armor"
@@ -63,7 +65,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
GuardrailEventHooks.during_call,
GuardrailEventHooks.post_call,
]
# Initialize parent classes first
super().__init__(**kwargs)
VertexBase.__init__(self)
@@ -293,9 +295,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
filters = (
list(filter_results.values())
if isinstance(filter_results, dict)
else filter_results
if isinstance(filter_results, list)
else []
else filter_results if isinstance(filter_results, list) else []
)
# Prefer sanitized text from deidentifyResult if present
@@ -360,18 +360,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
"mcp_call",
"anthropic_messages",
],
call_type: CallTypesLiteral,
) -> Union[Exception, str, dict, None]:
"""Pre-call hook to sanitize user prompts."""
verbose_proxy_logger.debug("Inside Model Armor Pre-Call Hook")
@@ -475,16 +464,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"responses",
"mcp_call",
"anthropic_messages",
],
call_type: CallTypesLiteral,
) -> Union[Exception, str, dict, None]:
"""During-call hook to sanitize user prompts in parallel with LLM call."""
verbose_proxy_logger.debug("Inside Model Armor Moderation Hook")
@@ -582,6 +562,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
):
"""Post-call hook to sanitize model responses."""
from litellm.proxy.common_utils.callback_utils import (
add_guardrail_response_to_standard_logging_object,
add_guardrail_to_applied_guardrails_header,
)
@@ -610,15 +591,31 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
)
# Attach Model Armor response & status to this request's metadata to prevent race conditions
if isinstance(data, dict):
metadata = data.setdefault("metadata", {})
metadata["_model_armor_response"] = armor_response
metadata["_model_armor_status"] = (
"blocked"
if self._should_block_content(
armor_response, allow_sanitization=self.mask_response_content
if isinstance(armor_response, dict):
model_armor_logged_object = {
"model_armor_response": armor_response,
"model_armor_status": (
"blocked"
if self._should_block_content(
armor_response,
allow_sanitization=self.mask_response_content,
)
else "success"
),
}
standard_logging_guardrail_information = (
StandardLoggingGuardrailInformation(
guardrail_name=self.guardrail_name,
guardrail_provider="model_armor",
guardrail_mode=GuardrailEventHooks.post_call,
guardrail_response=model_armor_logged_object,
guardrail_status="success",
start_time=data.get("start_time"),
)
else "success"
)
add_guardrail_response_to_standard_logging_object(
litellm_logging_obj=data.get("litellm_logging_obj"),
guardrail_response=standard_logging_guardrail_information,
)
# Check if content should be blocked
@@ -658,6 +655,8 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
request_data=data, guardrail_name=self.guardrail_name
)
return response
async def async_post_call_streaming_iterator_hook(
self,
user_api_key_dict: UserAPIKeyAuth,

View File

@@ -398,41 +398,17 @@ class ProxyLogging:
litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj) # type: ignore
for callback in litellm.callbacks:
if isinstance(callback, str):
callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class( # type: ignore
cast(_custom_logger_compatible_callbacks_literal, callback),
internal_usage_cache=self.internal_usage_cache.dual_cache,
llm_router=llm_router,
)
if callback is None:
continue
if callback not in litellm.input_callback:
litellm.input_callback.append(callback) # type: ignore
if callback not in litellm.success_callback:
litellm.logging_callback_manager.add_litellm_success_callback(callback) # type: ignore
if callback not in litellm.failure_callback:
litellm.logging_callback_manager.add_litellm_failure_callback(callback) # type: ignore
if callback not in litellm._async_success_callback:
litellm.logging_callback_manager.add_litellm_async_success_callback(callback) # type: ignore
if callback not in litellm._async_failure_callback:
litellm.logging_callback_manager.add_litellm_async_failure_callback(callback) # type: ignore
if callback not in litellm.service_callback:
litellm.service_callback.append(callback) # type: ignore
if (
len(litellm.input_callback) > 0
or len(litellm.success_callback) > 0
or len(litellm.failure_callback) > 0
):
callback_list = list(
set(
litellm.input_callback
+ litellm.success_callback
+ litellm.failure_callback
)
)
litellm.litellm_core_utils.litellm_logging.set_callbacks(
callback_list=callback_list
)
litellm.logging_callback_manager.add_litellm_callback(callback)
async def update_request_status(
self, litellm_call_id: str, status: Literal["success", "fail"]
@@ -1044,16 +1020,14 @@ class ProxyLogging:
event_type = GuardrailEventHooks.during_mcp_call
if (
callback.should_run_guardrail(
data=data, event_type=event_type
)
callback.should_run_guardrail(data=data, event_type=event_type)
is not True
):
continue
# Convert user_api_key_dict to proper format for async_moderation_hook
if call_type == "mcp_call":
user_api_key_auth_dict = (
self._convert_user_api_key_auth_to_dict(user_api_key_dict)
user_api_key_auth_dict = self._convert_user_api_key_auth_to_dict(
user_api_key_dict
)
else:
user_api_key_auth_dict = user_api_key_dict
@@ -1475,22 +1449,29 @@ class ProxyLogging:
):
continue
guardrail_response: Optional[Any] = None
if "apply_guardrail" in type(callback).__dict__:
data["guardrail_to_apply"] = callback
response = await unified_guardrail.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
guardrail_response = (
await unified_guardrail.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
)
else:
response = await callback.async_post_call_success_hook(
guardrail_response = await callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=response,
)
if guardrail_response is not None:
response = guardrail_response
############ Handle CustomLogger ###############################
#################################################################
for callback in other_callbacks:
await callback.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict, data=data, response=response

View File

@@ -117,5 +117,4 @@ async def test_chat_completion_check_otel_spans():
assert "postgres" in parent_trace_spans
assert "redis" in parent_trace_spans
assert "raw_gen_ai_request" in parent_trace_spans
assert "litellm_request" in parent_trace_spans
assert "batch_write_to_db" in parent_trace_spans