Model Armor - Logging guardrail response on llm responses (#16977)
* Litellm dev 11 22 2025 p1 (#16975)

* fix(model_armor.py): return response after applying changes

* fix: initial commit adding guardrail span logging to otel on post-call runs

  Sends it as a separate span right now; it still needs to be included in the same llm request/response span.

* fix(opentelemetry.py): include guardrail in received request log + set input/output fields on parent otel span instead of nesting it

  Allows request/response to be seen easily on observability tools.

* fix(model_armor.py): working model armor logging on post call events

* fix: fix exception message

* fix(opentelemetry.py): add backwards compatibility for litellm_request

  Allow users building on the spec change to use the previous spec.
@@ -8,6 +8,18 @@ OpenTelemetry is a CNCF standard for observability. It connects to any observabi

 <Image img={require('../../img/traceloop_dash.png')} />

+:::note Change in v1.81.0
+
+From v1.81.0, the request/response will be set as attributes on the parent "Received Proxy Server Request" span by default. This lets you see the request/response directly on the parent span in your observability tool.
+
+To keep the older behavior with nested "litellm_request" spans, set the following environment variable:
+
+```shell
+USE_OTEL_LITELLM_REQUEST_SPAN=true
+```
+
+:::
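
For intuition, here is a minimal standalone sketch (not part of the docs change) of the two trace shapes this flag selects between, written against the OpenTelemetry SDK directly. The span and attribute names are assumptions for illustration, not LiteLLM's exact keys:

```python
# Illustrative only: attributes-on-parent (new default) vs. a nested
# child span (USE_OTEL_LITELLM_REQUEST_SPAN=true).
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
trace.set_tracer_provider(provider)
tracer = trace.get_tracer("docs-sketch")

# Default from v1.81.0: request/response live on the parent span itself.
with tracer.start_as_current_span("Received Proxy Server Request") as parent:
    parent.set_attribute("llm.request", '{"messages": [...]}')   # assumed key
    parent.set_attribute("llm.response", '{"choices": [...]}')   # assumed key

# Older behavior: request/response live on a nested "litellm_request" span.
with tracer.start_as_current_span("Received Proxy Server Request"):
    with tracer.start_as_current_span("litellm_request") as child:
        child.set_attribute("llm.request", '{"messages": [...]}')
```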

 ## Getting Started

 Install the OpenTelemetry SDK:
@@ -99,13 +99,13 @@ class ArizeLogger(OpenTelemetry):
         """Arize is used mainly for LLM I/O tracing, sending router+caching metrics adds bloat to arize logs"""
         pass

-    def create_litellm_proxy_request_started_span(
-        self,
-        start_time: datetime,
-        headers: dict,
-    ):
-        """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
-        pass
+    # def create_litellm_proxy_request_started_span(
+    #     self,
+    #     start_time: datetime,
+    #     headers: dict,
+    # ):
+    #     """Arize is used mainly for LLM I/O tracing, sending Proxy Server Request adds bloat to arize logs"""
+    #     pass

     async def async_health_check(self):
         """
@@ -117,14 +117,10 @@ class ArizeLogger(OpenTelemetry):
         try:
             config = self.get_arize_config()

-            # Prefer ARIZE_SPACE_KEY, but fall back to ARIZE_SPACE_ID for backwards compatibility
-            effective_space_key = config.space_key or config.space_id
-
-            if not effective_space_key:
+            if not config.space_id and not config.space_key:
                 return {
                     "status": "unhealthy",
-                    # Tests (and users) expect the error message to reference ARIZE_SPACE_KEY
-                    "error_message": "ARIZE_SPACE_KEY environment variable not set",
+                    "error_message": "ARIZE_SPACE_ID or ARIZE_SPACE_KEY environment variable not set",
                 }

             if not config.api_key:
@@ -477,6 +477,7 @@ class CustomGuardrail(CustomLogger):
         """
         # Convert None to empty dict to satisfy type requirements
         guardrail_response = {} if response is None else response

         self.add_standard_logging_guardrail_information_to_request_data(
             guardrail_json_response=guardrail_response,
             request_data=request_data,
@@ -7,11 +7,13 @@ import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
+from litellm.secret_managers.main import get_secret_bool
 from litellm.types.services import ServiceLoggerPayload
 from litellm.types.utils import (
     ChatCompletionMessageToolCall,
     CostBreakdown,
     Function,
+    LLMResponseTypes,
     StandardCallbackDynamicParams,
     StandardLoggingPayload,
 )
@@ -487,6 +489,28 @@ class OpenTelemetry(CustomLogger):
         # End Parent OTEL Span
         parent_otel_span.end(end_time=self._to_ns(datetime.now()))

+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: LLMResponseTypes,
+    ):
+        from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+
+        litellm_logging_obj = data.get("litellm_logging_obj")
+
+        if litellm_logging_obj is not None and isinstance(
+            litellm_logging_obj, LiteLLMLogging
+        ):
+            kwargs = litellm_logging_obj.model_call_details
+            parent_span = user_api_key_dict.parent_otel_span
+
+            ctx, _ = self._get_span_context(kwargs, default_span=parent_span)
+
+            # 3. Guardrail span
+            self._create_guardrail_span(kwargs=kwargs, context=ctx)
+        return response
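
A minimal standalone sketch of the context plumbing the hook above relies on: creating a guardrail span as a child of an existing parent span via the public OpenTelemetry API. `_create_guardrail_span` is LiteLLM-internal; the span and attribute names below are assumptions for illustration.

```python
# Illustrative sketch: parent a new "guardrail" span under an existing span.
from opentelemetry import trace

tracer = trace.get_tracer("guardrail-sketch")

def emit_guardrail_span(parent_span: trace.Span, guardrail_name: str) -> None:
    # Build a Context whose current span is the parent, so the new span
    # becomes its child instead of starting a fresh trace.
    ctx = trace.set_span_in_context(parent_span)
    span = tracer.start_span("guardrail", context=ctx)
    span.set_attribute("guardrail.name", guardrail_name)  # assumed key
    span.end()
```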
+
     #########################################################
     # Team/Key Based Logging Control Flow
     #########################################################
@@ -565,8 +589,15 @@ class OpenTelemetry(CustomLogger):
         )
         ctx, parent_span = self._get_span_context(kwargs)

+        if get_secret_bool("USE_OTEL_LITELLM_REQUEST_SPAN"):
+            primary_span_parent = None
+        else:
+            primary_span_parent = parent_span
+
         # 1. Primary span
-        span = self._start_primary_span(kwargs, response_obj, start_time, end_time, ctx)
+        span = self._start_primary_span(
+            kwargs, response_obj, start_time, end_time, ctx, primary_span_parent
+        )

         # 2. Raw‐request sub-span (if enabled)
         self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span)
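
A hedged sketch of the flag-driven parenting above, with LiteLLM internals replaced by stand-ins: when the flag is unset (the new default) the existing parent span is reused and request/response attributes land on it; when set, the parent is dropped so a nested span is created, as before. `get_secret_bool` is approximated with `os.getenv`; all names are illustrative.

```python
# Stand-in for the logic above; not LiteLLM's actual helper.
import os
from typing import Optional
from opentelemetry import trace

def pick_primary_span(
    tracer: trace.Tracer, parent_span: Optional[trace.Span]
) -> trace.Span:
    # Assumed parsing of the env flag; LiteLLM uses get_secret_bool here.
    if os.getenv("USE_OTEL_LITELLM_REQUEST_SPAN", "").lower() == "true":
        parent_span = None  # force the older nested "litellm_request" span
    # Mirrors `span = parent_span or otel_tracer.start_span(...)` below.
    return parent_span or tracer.start_span("litellm_request")
```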
@@ -585,11 +616,19 @@ class OpenTelemetry(CustomLogger):
         if parent_span is not None:
             parent_span.end(end_time=self._to_ns(datetime.now()))

-    def _start_primary_span(self, kwargs, response_obj, start_time, end_time, context):
+    def _start_primary_span(
+        self,
+        kwargs,
+        response_obj,
+        start_time,
+        end_time,
+        context,
+        parent_span: Optional[Span] = None,
+    ):
         from opentelemetry.trace import Status, StatusCode

         otel_tracer: Tracer = self.get_tracer_to_use_for_request(kwargs)
-        span = otel_tracer.start_span(
+        span = parent_span or otel_tracer.start_span(
             name=self._get_span_name(kwargs),
             start_time=self._to_ns(start_time),
             context=context,
@@ -779,6 +818,7 @@ class OpenTelemetry(CustomLogger):
         guardrail_information_data = standard_logging_payload.get(
             "guardrail_information"
         )

         if not guardrail_information_data:
             return
@@ -1372,7 +1412,7 @@ class OpenTelemetry(CustomLogger):

         return _parent_context

-    def _get_span_context(self, kwargs):
+    def _get_span_context(self, kwargs, default_span: Optional[Span] = None):
         from opentelemetry import context, trace
         from opentelemetry.trace.propagation.tracecontext import (
             TraceContextTextMapPropagator,
@@ -3545,6 +3545,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
         _in_memory_loggers.append(_arize_otel_logger)
         return _arize_otel_logger  # type: ignore
     elif logging_integration == "arize_phoenix":
+
         from litellm.integrations.opentelemetry import (
             OpenTelemetry,
             OpenTelemetryConfig,
@@ -3574,9 +3575,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
         existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
         # Add openinference.project.name attribute
         if existing_attrs:
-            os.environ["OTEL_RESOURCE_ATTRIBUTES"] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
+            os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
+                f"{existing_attrs},openinference.project.name={phoenix_project_name}"
+            )
         else:
-            os.environ["OTEL_RESOURCE_ATTRIBUTES"] = f"openinference.project.name={phoenix_project_name}"
+            os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
+                f"openinference.project.name={phoenix_project_name}"
+            )

         # auth can be disabled on local deployments of arize phoenix
         if arize_phoenix_config.otlp_auth_headers is not None:
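
The hunk above relies on the OTEL_RESOURCE_ATTRIBUTES convention: a comma-separated list of key=value pairs that the OpenTelemetry SDK parses into resource attributes. A small self-contained sketch of the same append-or-initialize move (the project name is a made-up example):

```python
import os

def append_resource_attribute(key: str, value: str) -> None:
    """Append key=value to OTEL_RESOURCE_ATTRIBUTES, comma-separating pairs."""
    existing = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
    pair = f"{key}={value}"
    os.environ["OTEL_RESOURCE_ATTRIBUTES"] = f"{existing},{pair}" if existing else pair

append_resource_attribute("openinference.project.name", "my-phoenix-project")
print(os.environ["OTEL_RESOURCE_ATTRIBUTES"])
```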
@@ -1,15 +1,21 @@
-model_list:
-  - model_name: gpt-5-mini
-    litellm_params:
-      model: gpt-5-mini
-  - model_name: embedding-model
+model_list:
+  - model_name: gpt-3.5-turbo
     litellm_params:
-      model: openai/text-embedding-3-large
-
-  - model_name: gpt-4o-mini-transcribe
-    litellm_params:
-      model: openai/gpt-4o-mini-transcribe
+      model: openai/gpt-3.5-turbo
+      api_key: os.environ/OPENAI_API_KEY

+guardrails:
+  - guardrail_name: model-armor-shield
+    litellm_params:
+      guardrail: model_armor
+      mode: "post_call" # Run on both input and output
+      template_id: "test-prompt-template" # Required: Your Model Armor template ID
+      project_id: "test-vector-store-db" # Your GCP project ID
+      location: "us" # GCP location (default: us-central1)
+      mask_request_content: true # Enable request content masking
+      mask_response_content: true # Enable response content masking
+      fail_on_error: true # Fail request if Model Armor errors (default: true)
+      default_on: true # Run by default for all requests

 litellm_settings:
-  callbacks: ["arize"]
+  callbacks: ["arize_phoenix"]
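
For context, a hedged usage sketch of the config above: calling the proxy with the OpenAI SDK so the model-armor-shield guardrail (default_on: true) runs on the response. The base_url and api_key values are assumptions for a local proxy, not values from this commit.

```python
# Assumes a local LiteLLM proxy started with the config above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")  # assumed values

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "My SSN is 123-45-6789"}],
)
# With mask_response_content: true, Model Armor may sanitize the output here,
# and this commit logs the guardrail response on the request's logging payload.
print(resp.choices[0].message.content)
```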
@@ -1,14 +1,21 @@
-from typing import Any, Dict, Iterable, List, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional

 import litellm
 from litellm import get_secret
 from litellm._logging import verbose_proxy_logger
 from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
 from litellm.proxy.types_utils.utils import get_instance_fn
+from litellm.types.utils import (
+    StandardLoggingGuardrailInformation,
+    StandardLoggingPayload,
+)

 blue_color_code = "\033[94m"
 reset_color_code = "\033[0m"

+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+

 def initialize_callbacks_on_proxy(  # noqa: PLR0915
     value: Any,
@@ -365,6 +372,26 @@ def add_guardrail_to_applied_guardrails_header(
         _metadata["applied_guardrails"] = [guardrail_name]


+def add_guardrail_response_to_standard_logging_object(
+    litellm_logging_obj: Optional["LiteLLMLogging"],
+    guardrail_response: StandardLoggingGuardrailInformation,
+):
+    if litellm_logging_obj is None:
+        return
+    standard_logging_object: Optional[StandardLoggingPayload] = (
+        litellm_logging_obj.model_call_details.get("standard_logging_object")
+    )
+    if standard_logging_object is None:
+        return
+    guardrail_information = standard_logging_object.get("guardrail_information", [])
+    if guardrail_information is None:
+        guardrail_information = []
+    guardrail_information.append(guardrail_response)
+    standard_logging_object["guardrail_information"] = guardrail_information
+
+    return standard_logging_object
+
+
 def get_metadata_variable_name_from_kwargs(
     kwargs: dict,
 ) -> Literal["metadata", "litellm_metadata"]:
@@ -29,10 +29,12 @@ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.guardrails import GuardrailEventHooks
 from litellm.types.utils import (
+    CallTypesLiteral,
     Choices,
     GuardrailStatus,
     ModelResponse,
     ModelResponseStream,
     StandardLoggingGuardrailInformation,
 )

 GUARDRAIL_NAME = "model_armor"
@@ -63,7 +65,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
             GuardrailEventHooks.during_call,
             GuardrailEventHooks.post_call,
         ]

         # Initialize parent classes first
         super().__init__(**kwargs)
         VertexBase.__init__(self)
@@ -293,9 +295,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
         filters = (
             list(filter_results.values())
             if isinstance(filter_results, dict)
-            else filter_results
-            if isinstance(filter_results, list)
-            else []
+            else filter_results if isinstance(filter_results, list) else []
         )

         # Prefer sanitized text from deidentifyResult if present
@@ -360,18 +360,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
         user_api_key_dict: UserAPIKeyAuth,
         cache: DualCache,
         data: dict,
-        call_type: Literal[
-            "completion",
-            "text_completion",
-            "embeddings",
-            "image_generation",
-            "moderation",
-            "audio_transcription",
-            "pass_through_endpoint",
-            "rerank",
-            "mcp_call",
-            "anthropic_messages",
-        ],
+        call_type: CallTypesLiteral,
     ) -> Union[Exception, str, dict, None]:
         """Pre-call hook to sanitize user prompts."""
         verbose_proxy_logger.debug("Inside Model Armor Pre-Call Hook")
@@ -475,16 +464,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
         self,
         data: dict,
         user_api_key_dict: UserAPIKeyAuth,
-        call_type: Literal[
-            "completion",
-            "embeddings",
-            "image_generation",
-            "moderation",
-            "audio_transcription",
-            "responses",
-            "mcp_call",
-            "anthropic_messages",
-        ],
+        call_type: CallTypesLiteral,
     ) -> Union[Exception, str, dict, None]:
         """During-call hook to sanitize user prompts in parallel with LLM call."""
         verbose_proxy_logger.debug("Inside Model Armor Moderation Hook")
@@ -582,6 +562,7 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
     ):
         """Post-call hook to sanitize model responses."""
         from litellm.proxy.common_utils.callback_utils import (
+            add_guardrail_response_to_standard_logging_object,
             add_guardrail_to_applied_guardrails_header,
         )

@@ -610,15 +591,31 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
         )

-        # Attach Model Armor response & status to this request's metadata to prevent race conditions
-        if isinstance(data, dict):
-            metadata = data.setdefault("metadata", {})
-            metadata["_model_armor_response"] = armor_response
-            metadata["_model_armor_status"] = (
-                "blocked"
-                if self._should_block_content(
-                    armor_response, allow_sanitization=self.mask_response_content
-                )
-                else "success"
-            )
+        if isinstance(armor_response, dict):
+            model_armor_logged_object = {
+                "model_armor_response": armor_response,
+                "model_armor_status": (
+                    "blocked"
+                    if self._should_block_content(
+                        armor_response,
+                        allow_sanitization=self.mask_response_content,
+                    )
+                    else "success"
+                ),
+            }
+            standard_logging_guardrail_information = (
+                StandardLoggingGuardrailInformation(
+                    guardrail_name=self.guardrail_name,
+                    guardrail_provider="model_armor",
+                    guardrail_mode=GuardrailEventHooks.post_call,
+                    guardrail_response=model_armor_logged_object,
+                    guardrail_status="success",
+                    start_time=data.get("start_time"),
+                )
+            )
+            add_guardrail_response_to_standard_logging_object(
+                litellm_logging_obj=data.get("litellm_logging_obj"),
+                guardrail_response=standard_logging_guardrail_information,
+            )

         # Check if content should be blocked
@@ -658,6 +655,8 @@ class ModelArmorGuardrail(CustomGuardrail, VertexBase):
             request_data=data, guardrail_name=self.guardrail_name
         )

+        return response
+
     async def async_post_call_streaming_iterator_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
@@ -398,41 +398,17 @@ class ProxyLogging:
         litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj)  # type: ignore
         for callback in litellm.callbacks:
             if isinstance(callback, str):
                 callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
                     cast(_custom_logger_compatible_callbacks_literal, callback),
                     internal_usage_cache=self.internal_usage_cache.dual_cache,
                     llm_router=llm_router,
                 )
                 if callback is None:
                     continue
-            if callback not in litellm.input_callback:
-                litellm.input_callback.append(callback)  # type: ignore
-            if callback not in litellm.success_callback:
-                litellm.logging_callback_manager.add_litellm_success_callback(callback)  # type: ignore
-            if callback not in litellm.failure_callback:
-                litellm.logging_callback_manager.add_litellm_failure_callback(callback)  # type: ignore
-            if callback not in litellm._async_success_callback:
-                litellm.logging_callback_manager.add_litellm_async_success_callback(callback)  # type: ignore
-            if callback not in litellm._async_failure_callback:
-                litellm.logging_callback_manager.add_litellm_async_failure_callback(callback)  # type: ignore
-            if callback not in litellm.service_callback:
-                litellm.service_callback.append(callback)  # type: ignore
-
-        if (
-            len(litellm.input_callback) > 0
-            or len(litellm.success_callback) > 0
-            or len(litellm.failure_callback) > 0
-        ):
-            callback_list = list(
-                set(
-                    litellm.input_callback
-                    + litellm.success_callback
-                    + litellm.failure_callback
-                )
-            )
-            litellm.litellm_core_utils.litellm_logging.set_callbacks(
-                callback_list=callback_list
-            )
+            litellm.logging_callback_manager.add_litellm_callback(callback)

     async def update_request_status(
         self, litellm_call_id: str, status: Literal["success", "fail"]
@@ -1044,16 +1020,14 @@ class ProxyLogging:
                     event_type = GuardrailEventHooks.during_mcp_call

                 if (
-                    callback.should_run_guardrail(
-                        data=data, event_type=event_type
-                    )
+                    callback.should_run_guardrail(data=data, event_type=event_type)
                     is not True
                 ):
                     continue
                 # Convert user_api_key_dict to proper format for async_moderation_hook
                 if call_type == "mcp_call":
-                    user_api_key_auth_dict = (
-                        self._convert_user_api_key_auth_to_dict(user_api_key_dict)
+                    user_api_key_auth_dict = self._convert_user_api_key_auth_to_dict(
+                        user_api_key_dict
                     )
                 else:
                     user_api_key_auth_dict = user_api_key_dict
@@ -1475,22 +1449,29 @@ class ProxyLogging:
                 ):
                     continue

+            guardrail_response: Optional[Any] = None
             if "apply_guardrail" in type(callback).__dict__:
                 data["guardrail_to_apply"] = callback
-                response = await unified_guardrail.async_post_call_success_hook(
-                    user_api_key_dict=user_api_key_dict,
-                    data=data,
-                    response=response,
+                guardrail_response = (
+                    await unified_guardrail.async_post_call_success_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        data=data,
+                        response=response,
+                    )
                 )
             else:
-                response = await callback.async_post_call_success_hook(
+                guardrail_response = await callback.async_post_call_success_hook(
                     user_api_key_dict=user_api_key_dict,
                     data=data,
                     response=response,
                 )

+            if guardrail_response is not None:
+                response = guardrail_response
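
A minimal sketch of the contract this change establishes, using hypothetical hook objects rather than LiteLLM classes: a post-call hook may return None to mean "no change", and only a non-None return replaces the response.

```python
# Hypothetical stand-ins for the hook protocol above; not LiteLLM classes.
from typing import Any, Optional

class UppercaseHook:
    async def async_post_call_success_hook(self, response: str) -> Optional[str]:
        return response.upper()  # non-None: replaces the response

class ObserveOnlyHook:
    async def async_post_call_success_hook(self, response: str) -> Optional[str]:
        return None  # logs/inspects only; response must survive unchanged

async def run_post_call_hooks(response: Any, hooks: list) -> Any:
    for hook in hooks:
        result = await hook.async_post_call_success_hook(response=response)
        if result is not None:
            response = result
    return response
```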

         ############ Handle CustomLogger ###############################
         #################################################################

         for callback in other_callbacks:
             await callback.async_post_call_success_hook(
                 user_api_key_dict=user_api_key_dict, data=data, response=response
@@ -117,5 +117,4 @@ async def test_chat_completion_check_otel_spans():
     assert "postgres" in parent_trace_spans
     assert "redis" in parent_trace_spans
     assert "raw_gen_ai_request" in parent_trace_spans
-    assert "litellm_request" in parent_trace_spans
     assert "batch_write_to_db" in parent_trace_spans