Mirror of https://github.com/BerriAI/litellm.git
[Feat] Add support for Batch API Rate limiting - PR1 adds support for input based rate limits (#16075)
* add count_input_file_usage
* add count_input_file_usage
* fix count_input_file_usage
* _get_batch_job_input_file_usage
* fixes imports
* use _get_batch_job_input_file_usage
* test_batch_rate_limits
* add _check_and_increment_batch_counters
* add get_rate_limiter_for_call_type
* test_batch_rate_limit_multiple_requests
* fixes for batch limits
* fix linting
* fix MYPY linting
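The changes above boil down to one idea: estimate a batch's usage from its input file before the batch is created, then charge that estimate against the key's limits. A minimal sketch of the counting step, assuming a local `batch_small.jsonl` like the one added in this PR (it mirrors what the new `_get_batch_job_input_file_usage` helper does; it is not the proxy code itself):

```python
import json

import litellm

# Parse an OpenAI-style batch input file (one JSON request per line).
with open("batch_small.jsonl") as f:  # hypothetical local path
    requests = [json.loads(line) for line in f if line.strip()]

prompt_tokens = 0
for item in requests:
    body = item.get("body", {})
    # litellm.token_counter estimates prompt tokens for a list of chat messages.
    prompt_tokens += litellm.token_counter(
        model=body.get("model", ""), messages=body.get("messages", [])
    )

print(f"requests={len(requests)}, estimated prompt tokens={prompt_tokens}")
```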
batch_small.jsonl (new file, 4 lines)
@@ -0,0 +1,4 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}

litellm/batches/batch_utils.py
@@ -1,10 +1,15 @@
 import json
-from typing import Any, List, Literal, Tuple, Optional
+import time
+from typing import Any, List, Literal, Optional, Tuple
 
+import httpx
+
 import litellm
 from litellm._logging import verbose_logger
+from litellm._uuid import uuid
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import CallTypes, Usage
+from litellm.types.utils import CallTypes, ModelResponse, Usage
+from litellm.utils import token_counter
 
 
 async def calculate_batch_cost_and_usage(
@@ -107,6 +112,10 @@ def calculate_vertex_ai_batch_cost_and_usage(
     """
     Calculate both cost and usage from Vertex AI batch responses
     """
+    from litellm.litellm_core_utils.litellm_logging import Logging
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+        VertexGeminiConfig,
+    )
     total_cost = 0.0
     total_tokens = 0
     prompt_tokens = 0
@@ -115,13 +124,6 @@ def calculate_vertex_ai_batch_cost_and_usage(
     for response in vertex_ai_batch_responses:
         if response.get("status") == "JOB_STATE_SUCCEEDED":  # Check if response was successful
             # Transform Vertex AI response to OpenAI format if needed
-            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
-            from litellm import ModelResponse
-            from litellm.litellm_core_utils.litellm_logging import Logging
-            from litellm.types.utils import CallTypes
-            from litellm._uuid import uuid
-            import httpx
-            import time
 
             # Create required arguments for the transformation method
             model_response = ModelResponse()
@@ -163,8 +165,9 @@ def calculate_vertex_ai_batch_cost_and_usage(
             total_cost += cost
 
             # Extract usage from the transformed response
-            if hasattr(openai_format_response, 'usage') and openai_format_response.usage:
-                usage = openai_format_response.usage
+            usage_obj = getattr(openai_format_response, 'usage', None)
+            if usage_obj:
+                usage = usage_obj
             else:
                 # Fallback: create usage from response dict
                 response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
@@ -278,6 +281,33 @@ def _get_batch_job_total_usage_from_file_content(
         completion_tokens=completion_tokens,
     )
 
+
+def _get_batch_job_input_file_usage(
+    file_content_dictionary: List[dict],
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
+) -> Usage:
+    """
+    Count the number of tokens in the input file
+
+    Used for batch rate limiting to count the number of tokens in the input file
+    """
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+
+    for _item in file_content_dictionary:
+        body = _item.get("body", {})
+        model = body.get("model", model_name or "")
+        messages = body.get("messages", [])
+
+        if messages:
+            item_tokens = token_counter(model=model, messages=messages)
+            prompt_tokens += item_tokens
+
+    return Usage(
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+
 
 def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
     """
litellm/proxy/hooks/batch_rate_limiter.py (new file, 376 lines)
@@ -0,0 +1,376 @@
"""
Batch Rate Limiter Hook

This hook implements rate limiting for batch API requests by:
1. Reading batch input files to count requests and estimate tokens at submission
2. Validating actual usage from output files when batches complete
3. Integrating with the existing parallel request limiter infrastructure

## Integration & Calling
This hook is automatically registered and called by the proxy system.
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.

Quick summary:
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
    _get_batch_job_input_file_usage,
    _get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]


class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting
    """
    total_tokens: int
    request_count: int


class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads input file and reserves capacity
    2. Batch completion - reads output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter

    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise HTTPException for rate limit exceeded."""
        from datetime import datetime

        # Find the descriptor for this status
        descriptor_index = next(
            (i for i, d in enumerate(descriptors)
             if d.get("key") == status.get("descriptor_key")),
            0
        )
        descriptor: RateLimitDescriptor = descriptors[descriptor_index] if descriptors else {"key": "", "value": "", "rate_limit": None}

        now = datetime.now().timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(reset_time).strftime(
            "%Y-%m-%d %H:%M:%S UTC"
        )

        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )

    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.

        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]

            required_capacity = (
                batch_usage.request_count if rate_limit_type == "requests"
                else batch_usage.total_tokens if rate_limit_type == "tokens"
                else 0
            )

            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for batch increments
        # Reuse the same keys that descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []

        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")

            if rate_limit is None:
                continue

            # Add RPM increment if limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add TPM increment if limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )

    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    ) -> BatchFileUsage:
        """
        Count number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
        try:
            # Read file content
            file_content = await litellm.afile_content(
                file_id=file_id,
                custom_llm_provider=custom_llm_provider,
            )

            file_content_as_dict = _get_file_content_as_dictionary(
                file_content.content
            )

            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)
            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: Not handling batch creation rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
            )

            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Directly increment counters by batch amounts (check happens atomically)
            # This will raise HTTPException if limits are exceeded
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug("Batch rate limit check passed, counters incremented")
            return data

        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data
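Taken together, the hook is driven like this for a batch submission. A condensed sketch mirroring the integration tests further down; it assumes an OPENAI_API_KEY in the environment and an already-uploaded batch input file, and "file-abc123" is a placeholder ID:

```python
import asyncio

from litellm import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    _PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.utils import InternalUsageCache


async def main() -> None:
    dual_cache = DualCache()
    handler = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=InternalUsageCache(dual_cache=dual_cache)
    )
    batch_limiter = handler._get_batch_rate_limiter()  # lazily constructed hook

    # Key allowed 200 tokens / 10 requests per window; placeholder file ID.
    data = await batch_limiter.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-test", tpm_limit=200, rpm_limit=10),
        cache=dual_cache,
        data={
            "model": "gpt-3.5-turbo",
            "input_file_id": "file-abc123",
            "custom_llm_provider": "openai",
        },
        call_type="acreate_batch",
    )
    # On success the estimated usage is stashed on the request; a 429
    # HTTPException is raised instead when the file would exceed the key's limits.
    print(data["_batch_token_count"], data["_batch_request_count"])


asyncio.run(main())
```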
litellm/proxy/hooks/parallel_request_limiter_v3.py
@@ -162,6 +162,27 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
 
         self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
 
+        # Batch rate limiter (lazy loaded)
+        self._batch_rate_limiter: Optional[Any] = None
+
+    def _get_batch_rate_limiter(self) -> Optional[Any]:
+        """Get or lazy-load the batch rate limiter."""
+        if self._batch_rate_limiter is None:
+            try:
+                from litellm.proxy.hooks.batch_rate_limiter import (
+                    _PROXY_BatchRateLimiter,
+                )
+
+                self._batch_rate_limiter = _PROXY_BatchRateLimiter(
+                    internal_usage_cache=self.internal_usage_cache,
+                    parallel_request_limiter=self,
+                )
+            except Exception as e:
+                verbose_proxy_logger.debug(
+                    f"Could not load batch rate limiter: {str(e)}"
+                )
+        return self._batch_rate_limiter
+
     def _get_current_time(self) -> datetime:
         """Return the current time for rate limiting calculations."""
         return self._time_provider()
@@ -987,6 +1008,13 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
             # Fail safe: enforce limits if we can't check
             return True
 
+    def get_rate_limiter_for_call_type(self, call_type: str) -> Optional[Any]:
+        """Get the rate limiter for the call type."""
+        if call_type == "acreate_batch":
+            batch_limiter = self._get_batch_rate_limiter()
+            return batch_limiter
+        return None
+
     async def async_pre_call_hook(
         self,
         user_api_key_dict: UserAPIKeyAuth,
@@ -1000,6 +1028,19 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
         """
         verbose_proxy_logger.debug("Inside Rate Limit Pre-Call Hook")
 
+        #########################################################
+        # Check if the call type has a specific rate limiter
+        # eg. for Batch APIs we need to use the batch rate limiter to read the input file and count the tokens and requests
+        #########################################################
+        call_type_specific_rate_limiter = self.get_rate_limiter_for_call_type(call_type=call_type)
+        if call_type_specific_rate_limiter:
+            return await call_type_specific_rate_limiter.async_pre_call_hook(
+                user_api_key_dict=user_api_key_dict,
+                cache=cache,
+                data=data,
+                call_type=call_type,
+            )
+
         # Get rate limit types from metadata
         metadata = user_api_key_dict.metadata or {}
         rpm_limit_type = metadata.get("rpm_limit_type")
@@ -1470,6 +1511,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
                 f"Error in rate limit failure event: {str(e)}"
            )
 
+
    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):
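A quick sketch of the new dispatch, using the same constructor the tests use: only `acreate_batch` is routed to the batch rate limiter, everything else falls through to the normal RPM/TPM checks.

```python
from litellm import DualCache
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    _PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.utils import InternalUsageCache

handler = _PROXY_MaxParallelRequestsHandler_v3(
    internal_usage_cache=InternalUsageCache(dual_cache=DualCache())
)

# Batch submissions get the lazily-loaded batch rate limiter ...
assert handler.get_rate_limiter_for_call_type(call_type="acreate_batch") is not None
# ... every other call type returns None and uses the standard pre-call checks.
assert handler.get_rate_limiter_for_call_type(call_type="acompletion") is None
```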
@@ -22,5 +22,13 @@ search_tools:
     search_provider: exa_ai
     api_key: os.environ/EXA_API_KEY
 
+
+
+# for /files endpoints
+files_settings:
+  - custom_llm_provider: openai
+    api_key: os.environ/OPENAI_API_KEY
+
+
 litellm_settings:
   callbacks: ["datadog"]
tests/batches_tests/batch_small.jsonl (new file, 14 lines)
@@ -0,0 +1,14 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
tests/batches_tests/test_batch_rate_limits.py (new file, 391 lines)
@@ -0,0 +1,391 @@
"""
Integration Tests for Batch Rate Limits
"""

import asyncio
import json
import os
import sys

import pytest
from fastapi import HTTPException

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path

import litellm
from litellm import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.batch_rate_limiter import (
    BatchFileUsage,
    _PROXY_BatchRateLimiter,
)
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    _PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.utils import InternalUsageCache


def get_expected_batch_file_usage(file_path: str) -> tuple[int, int]:
    """
    Helper function to calculate expected request count and token count from a batch JSONL file.

    Returns:
        tuple[int, int]: (expected_request_count, expected_total_tokens)
    """
    with open(file_path, 'r') as f:
        file_contents = [json.loads(line) for line in f if line.strip()]

    expected_request_count = len(file_contents)
    expected_total_tokens = 0

    for item in file_contents:
        body = item.get("body", {})
        model = body.get("model", "")
        messages = body.get("messages", [])
        if messages:
            item_tokens = litellm.token_counter(model=model, messages=messages)
            expected_total_tokens += item_tokens

    return expected_request_count, expected_total_tokens


@pytest.mark.asyncio()
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None,
    reason="OPENAI_API_KEY not set - skipping integration test"
)
async def test_batch_rate_limits():
    """
    Integration test for batch rate limits with real OpenAI API calls.
    Tests the full flow: file creation -> token counting -> cleanup
    """
    litellm._turn_on_debug()
    CUSTOM_LLM_PROVIDER = "openai"
    BATCH_LIMITER = _PROXY_BatchRateLimiter(
        internal_usage_cache=None,
        parallel_request_limiter=None,
    )

    file_name = "openai_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)

    # Create file on OpenAI
    print(f"Creating file from {file_path}")
    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider=CUSTOM_LLM_PROVIDER,
    )
    print(f"Response from creating file: {file_obj}")

    assert file_obj.id is not None, "File ID should not be None"

    # Give API a moment to process the file
    await asyncio.sleep(1)

    # Count requests and token usage in input file
    tracked_batch_file_usage: BatchFileUsage = await BATCH_LIMITER.count_input_file_usage(
        file_id=file_obj.id,
        custom_llm_provider=CUSTOM_LLM_PROVIDER,
    )
    print(f"Actual total tokens: {tracked_batch_file_usage.total_tokens}")
    print(f"Actual request count: {tracked_batch_file_usage.request_count}")

    # Calculate expected values by reading the JSONL file
    expected_request_count, expected_total_tokens = get_expected_batch_file_usage(file_path=file_path)

    print(f"Expected request count: {expected_request_count}")
    print(f"Expected total tokens: {expected_total_tokens}")

    # Verify token counting results
    assert tracked_batch_file_usage.request_count == expected_request_count, f"Expected {expected_request_count} requests, got {tracked_batch_file_usage.request_count}"
    assert tracked_batch_file_usage.total_tokens == expected_total_tokens, f"Expected {expected_total_tokens} total_tokens, got {tracked_batch_file_usage.total_tokens}"


@pytest.mark.asyncio()
async def test_batch_rate_limit_single_file():
    """
    Test batch rate limiting with a single file.

    Key has TPM = 200
    - File with < 200 tokens: should go through
    - File with > 200 tokens: should hit rate limit
    """
    import tempfile

    CUSTOM_LLM_PROVIDER = "openai"

    # Setup: Create internal usage cache and rate limiter
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )

    # Setup: Get batch rate limiter
    batch_limiter = rate_limiter._get_batch_rate_limiter()
    assert batch_limiter is not None, "Batch rate limiter should be available"

    # Setup: Create user API key with TPM = 200
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test-key-123",
        tpm_limit=200,
        rpm_limit=10,
    )

    # Test 1: File with < 200 tokens should go through
    print("\n=== Test 1: File under 200 tokens ===")

    # Create a small batch file with ~150 tokens
    small_batch_content = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey"}]}}"""

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(small_batch_content)
        small_file_path = f.name

    try:
        # Upload file to OpenAI
        file_obj_small = await litellm.acreate_file(
            file=open(small_file_path, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created small file: {file_obj_small.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_under_limit = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_small.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should not raise an exception
        result = await batch_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=dual_cache,
            data=data_under_limit,
            call_type="acreate_batch",
        )
        print(f"✓ File with ~150 tokens passed (under limit of 200)")
        print(f" Actual tokens: {result.get('_batch_token_count')}")
    except HTTPException as e:
        pytest.fail(f"Should not have hit rate limit with small file: {e.detail}")
    finally:
        os.unlink(small_file_path)

    # Test 2: File with > 200 tokens should hit rate limit
    print("\n=== Test 2: File over 200 tokens ===")

    # Reset cache for clean test
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )
    batch_limiter = rate_limiter._get_batch_rate_limiter()

    # Create a larger batch file with ~10000+ tokens (100x larger to ensure it exceeds 200 token limit)
    base_message = "This is a longer message that will consume more tokens from the rate limit. " * 100

    # Build JSONL content with json.dumps to avoid f-string nesting issues
    import json as json_lib
    requests = []
    for i in range(1, 4):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": base_message}]
            }
        }
        requests.append(json_lib.dumps(request_obj))

    large_batch_content = "\n".join(requests)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(large_batch_content)
        large_file_path = f.name

    try:
        # Upload file to OpenAI
        file_obj_large = await litellm.acreate_file(
            file=open(large_file_path, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created large file: {file_obj_large.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_over_limit = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_large.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should raise HTTPException with 429 status
        with pytest.raises(HTTPException) as exc_info:
            await batch_limiter.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=dual_cache,
                data=data_over_limit,
                call_type="acreate_batch",
            )

        assert exc_info.value.status_code == 429, "Should return 429 status code"
        assert "tokens" in exc_info.value.detail.lower(), "Error message should mention tokens"
        print(f"✓ File with 250+ tokens correctly rejected (over limit of 200)")
        print(f" Error: {exc_info.value.detail}")
    finally:
        os.unlink(large_file_path)


@pytest.mark.asyncio()
async def test_batch_rate_limit_multiple_requests():
    """
    Test batch rate limiting with multiple requests.

    Key has TPM = 200
    - Request 1: file with ~100 tokens (should go through, 100/200 used)
    - Request 2: file with ~105 tokens (should hit limit, 100+105=205 > 200)
    """
    import tempfile

    CUSTOM_LLM_PROVIDER = "openai"

    # Setup: Create internal usage cache and rate limiter
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )

    # Setup: Get batch rate limiter
    batch_limiter = rate_limiter._get_batch_rate_limiter()
    assert batch_limiter is not None, "Batch rate limiter should be available"

    # Setup: Create user API key with TPM = 200
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test-key-456",
        tpm_limit=200,
        rpm_limit=10,
    )

    # Request 1: File with ~100 tokens
    print("\n=== Request 1: File with ~100 tokens ===")

    # Create file with ~100 tokens
    import json as json_lib
    message_1 = "This message has some content to reach about 100 tokens total. " * 4
    requests_1 = []
    for i in range(1, 3):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": message_1}]
            }
        }
        requests_1.append(json_lib.dumps(request_obj))

    batch_content_1 = "\n".join(requests_1)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(batch_content_1)
        file_path_1 = f.name

    try:
        # Upload file to OpenAI
        file_obj_1 = await litellm.acreate_file(
            file=open(file_path_1, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created file 1: {file_obj_1.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_request1 = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_1.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should not raise an exception
        result1 = await batch_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=dual_cache,
            data=data_request1,
            call_type="acreate_batch",
        )
        tokens_used_1 = result1.get('_batch_token_count', 0)
        print(f"✓ Request 1 with {tokens_used_1} tokens passed ({tokens_used_1}/200 used)")
    except HTTPException as e:
        pytest.fail(f"Request 1 should not have hit rate limit: {e.detail}")
    finally:
        os.unlink(file_path_1)

    # Request 2: File with ~105+ tokens (total would exceed 200)
    print("\n=== Request 2: File with ~105 tokens (should hit limit) ===")

    # Create file with ~105+ tokens
    message_2 = "This is another message with more content to exceed the remaining limit. " * 11
    requests_2 = []
    for i in range(1, 3):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": message_2}]
            }
        }
        requests_2.append(json_lib.dumps(request_obj))

    batch_content_2 = "\n".join(requests_2)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(batch_content_2)
        file_path_2 = f.name

    try:
        # Upload file to OpenAI
        file_obj_2 = await litellm.acreate_file(
            file=open(file_path_2, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created file 2: {file_obj_2.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_request2 = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_2.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should raise HTTPException with 429 status
        with pytest.raises(HTTPException) as exc_info:
            await batch_limiter.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=dual_cache,
                data=data_request2,
                call_type="acreate_batch",
            )

        assert exc_info.value.status_code == 429, "Should return 429 status code"
        assert "tokens" in exc_info.value.detail.lower(), "Error message should mention tokens"
        print(f"✓ Request 2 correctly rejected")
        print(f" Error: {exc_info.value.detail}")
    finally:
        os.unlink(file_path_2)
@@ -0,0 +1,4 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "What is 2+2?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Tell me a joke about programming"}]}}