"""
|
|
Batch Rate Limiter Hook
|
|
|
|
This hook implements rate limiting for batch API requests by:
|
|
1. Reading batch input files to count requests and estimate tokens at submission
|
|
2. Validating actual usage from output files when batches complete
|
|
3. Integrating with the existing parallel request limiter infrastructure
|
|
|
|
## Integration & Calling
|
|
This hook is automatically registered and called by the proxy system.
|
|
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.
|
|
|
|
Quick summary:
|
|
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
|
|
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
|
|
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
|
|
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
|
|
"""
|
|
|
|
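
# Registration sketch (illustrative): per the docstring above, this hook is added
# to PROXY_HOOKS in litellm/proxy/hooks/__init__.py and auto-instantiated by
# ProxyLogging._add_proxy_hooks(). The exact shape of PROXY_HOOKS may differ from
# the hypothetical entry below; BATCH_RATE_LIMITER_INTEGRATION.md is the
# authoritative reference.
#
#   # litellm/proxy/hooks/__init__.py (hypothetical entry)
#   PROXY_HOOKS = {
#       ...,
#       "batch_rate_limiter": _PROXY_BatchRateLimiter,
#   }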

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
    _get_batch_job_input_file_usage,
    _get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]


class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting.
    """

    total_tokens: int
    request_count: int


class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads the input file and reserves capacity
    2. Batch completion - reads the output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter

    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise an HTTPException (429) when the batch would exceed a rate limit."""
        from datetime import datetime

        # Find the descriptor for this status
        descriptor_index = next(
            (
                i
                for i, d in enumerate(descriptors)
                if d.get("key") == status.get("descriptor_key")
            ),
            0,
        )
        descriptor: RateLimitDescriptor = (
            descriptors[descriptor_index]
            if descriptors
            else {"key": "", "value": "", "rate_limit": None}
        )

        now = datetime.now().timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(reset_time).strftime(
            "%Y-%m-%d %H:%M:%S UTC"
        )

        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )

    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.

        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if the batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify the batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]

            required_capacity = (
                batch_usage.request_count
                if rate_limit_type == "requests"
                else batch_usage.total_tokens
                if rate_limit_type == "tokens"
                else 0
            )

            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for the batch increments
        # Reuse the same keys that the descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []

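        # Illustrative descriptor shape, inferred from the key accesses in the
        # loop below (the authoritative TypedDicts live in
        # parallel_request_limiter_v3); the values here are hypothetical:
        #
        #   {
        #       "key": "api_key",
        #       "value": "<hashed-key>",
        #       "rate_limit": {"requests_per_unit": 100, "tokens_per_unit": 100000},
        #   }
        #
        # For a batch of 50 requests / 20,000 tokens, such a descriptor yields two
        # pipeline operations: +50 on its "requests" key and +20000 on its
        # "tokens" key, each with the limiter's window_size as TTL.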
        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")

            if rate_limit is None:
                continue

            # Add an RPM increment if a request limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add a TPM increment if a token limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute the increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )

    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    ) -> BatchFileUsage:
        """
        Count the number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
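        # Expected input format (illustrative): for OpenAI-style batches the input
        # file is JSONL, one request object per line, roughly:
        #
        #   {"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions",
        #    "body": {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}}
        #
        # request_count is the number of such lines; token counting is delegated
        # to _get_batch_job_input_file_usage.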
        try:
            # Read the file content
            file_content = await litellm.afile_content(
                file_id=file_id,
                custom_llm_provider=custom_llm_provider,
            )

            file_content_as_dict = _get_file_content_as_dictionary(
                file_content.content
            )

            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)
            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads the input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if the rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: Not handling batch creation rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from the request data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from the input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
            )

            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Check limits and increment counters by the batch amounts;
            # this will raise HTTPException if any limit would be exceeded
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug(
                "Batch rate limit check passed, counters incremented"
            )
            return data

        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data
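

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; nothing below executes). In production the
# proxy injects the dependencies via ProxyLogging._add_proxy_hooks(), but manual
# wiring for a test might look like the following; `internal_usage_cache`,
# `parallel_request_limiter`, and `user_api_key_dict` are assumed to already
# exist:
#
#   limiter = _PROXY_BatchRateLimiter(
#       internal_usage_cache=internal_usage_cache,
#       parallel_request_limiter=parallel_request_limiter,
#   )
#   data = await limiter.async_pre_call_hook(
#       user_api_key_dict=user_api_key_dict,
#       cache=None,
#       data={"input_file_id": "file-abc123", "custom_llm_provider": "openai"},
#       call_type="acreate_batch",
#   )
#   # -> raises fastapi.HTTPException(status_code=429) if the batch's request
#   #    or token count exceeds the remaining RPM/TPM capacity.
# ---------------------------------------------------------------------------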