[Feat] Add support for Batch API Rate limiting - PR1 adds support for input-based rate limits (#16075)

* add count_input_file_usage

* add count_input_file_usage

* fix count_input_file_usage

* _get_batch_job_input_file_usage

* fixes imports

* use _get_batch_job_input_file_usage

* test_batch_rate_limits

* add _check_and_increment_batch_counters

* add get_rate_limiter_for_call_type

* test_batch_rate_limit_multiple_requests

* fixes for batch limits

* fix linting

* fix MYPY linting
Ishaan Jaff, 2025-10-29 18:28:52 -07:00 (committed by GitHub)
parent 8a7f39daa4 · commit aea78b8d1a
8 changed files with 881 additions and 12 deletions
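
In short: before a batch is created, the proxy now estimates the batch's usage from its input JSONL and checks that estimate against the key's TPM/RPM limits, rejecting with a 429 when the file alone would exceed them. A rough sketch of the estimation idea outside the proxy (assumes a local batch_input.jsonl in OpenAI batch format and a hypothetical 200-token TPM limit; litellm.token_counter is the public helper the new code builds on):

# Illustrative sketch only - not part of this commit.
import json

import litellm

TPM_LIMIT = 200  # hypothetical per-key tokens-per-minute limit

with open("batch_input.jsonl") as f:  # assumed local file in OpenAI batch format
    requests = [json.loads(line) for line in f if line.strip()]

estimated_prompt_tokens = sum(
    litellm.token_counter(model=r["body"]["model"], messages=r["body"]["messages"])
    for r in requests
)

print(f"{len(requests)} requests, ~{estimated_prompt_tokens} prompt tokens")
if estimated_prompt_tokens > TPM_LIMIT:
    print("the proxy would reject this batch with HTTP 429 before creating it")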

batch_small.jsonl (new file)

@@ -0,0 +1,4 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}

litellm/batches/batch_utils.py

@@ -1,10 +1,15 @@
 import json
-from typing import Any, List, Literal, Tuple, Optional
+import time
+from typing import Any, List, Literal, Optional, Tuple
+import httpx
 
 import litellm
 from litellm._logging import verbose_logger
+from litellm._uuid import uuid
 from litellm.types.llms.openai import Batch
-from litellm.types.utils import CallTypes, Usage
+from litellm.types.utils import CallTypes, ModelResponse, Usage
 from litellm.utils import token_counter
 
 
 async def calculate_batch_cost_and_usage(
@@ -107,6 +112,10 @@ def calculate_vertex_ai_batch_cost_and_usage(
     """
     Calculate both cost and usage from Vertex AI batch responses
     """
+    from litellm.litellm_core_utils.litellm_logging import Logging
+    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+        VertexGeminiConfig,
+    )
     total_cost = 0.0
     total_tokens = 0
     prompt_tokens = 0
@@ -115,13 +124,6 @@ def calculate_vertex_ai_batch_cost_and_usage(
     for response in vertex_ai_batch_responses:
         if response.get("status") == "JOB_STATE_SUCCEEDED":  # Check if response was successful
             # Transform Vertex AI response to OpenAI format if needed
-            from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
-            from litellm import ModelResponse
-            from litellm.litellm_core_utils.litellm_logging import Logging
-            from litellm.types.utils import CallTypes
-            from litellm._uuid import uuid
-            import httpx
-            import time
 
             # Create required arguments for the transformation method
             model_response = ModelResponse()
@@ -163,8 +165,9 @@ def calculate_vertex_ai_batch_cost_and_usage(
             total_cost += cost
 
             # Extract usage from the transformed response
-            if hasattr(openai_format_response, 'usage') and openai_format_response.usage:
-                usage = openai_format_response.usage
+            usage_obj = getattr(openai_format_response, 'usage', None)
+            if usage_obj:
+                usage = usage_obj
             else:
                 # Fallback: create usage from response dict
                 response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
@@ -278,6 +281,33 @@ def _get_batch_job_total_usage_from_file_content(
         completion_tokens=completion_tokens,
     )
 
+
+def _get_batch_job_input_file_usage(
+    file_content_dictionary: List[dict],
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    model_name: Optional[str] = None,
+) -> Usage:
+    """
+    Count the number of tokens in the batch input file.
+    Used for batch rate limiting to estimate prompt tokens before the batch is submitted.
+    """
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    for _item in file_content_dictionary:
+        body = _item.get("body", {})
+        model = body.get("model", model_name or "")
+        messages = body.get("messages", [])
+        if messages:
+            item_tokens = token_counter(model=model, messages=messages)
+            prompt_tokens += item_tokens
+    return Usage(
+        total_tokens=prompt_tokens + completion_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+    )
+
+
 def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
     """

litellm/proxy/hooks/batch_rate_limiter.py (new file)

@@ -0,0 +1,376 @@
"""
Batch Rate Limiter Hook
This hook implements rate limiting for batch API requests by:
1. Reading batch input files to count requests and estimate tokens at submission
2. Validating actual usage from output files when batches complete
3. Integrating with the existing parallel request limiter infrastructure
## Integration & Calling
This hook is automatically registered and called by the proxy system.
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.
Quick summary:
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
"""

from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union

from fastapi import HTTPException
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
    _get_batch_job_input_file_usage,
    _get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]


class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting
    """

    total_tokens: int
    request_count: int


class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads input file and reserves capacity
    2. Batch completion - reads output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter

    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise HTTPException for rate limit exceeded."""
        from datetime import datetime

        # Find the descriptor for this status
        descriptor_index = next(
            (i for i, d in enumerate(descriptors)
             if d.get("key") == status.get("descriptor_key")),
            0
        )
        descriptor: RateLimitDescriptor = descriptors[descriptor_index] if descriptors else {"key": "", "value": "", "rate_limit": None}

        now = datetime.now().timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(reset_time).strftime(
            "%Y-%m-%d %H:%M:%S UTC"
        )

        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )

    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.
        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]
            required_capacity = (
                batch_usage.request_count if rate_limit_type == "requests"
                else batch_usage.total_tokens if rate_limit_type == "tokens"
                else 0
            )
            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for batch increments
        # Reuse the same keys that descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []
        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")
            if rate_limit is None:
                continue

            # Add RPM increment if limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add TPM increment if limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )

    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    ) -> BatchFileUsage:
        """
        Count number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
        try:
            # Read file content
            file_content = await litellm.afile_content(
                file_id=file_id,
                custom_llm_provider=custom_llm_provider,
            )
            file_content_as_dict = _get_file_content_as_dictionary(
                file_content.content
            )

            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)

            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: Not handling batch creation rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
            )
            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Directly increment counters by batch amounts (check happens atomically)
            # This will raise HTTPException if limits are exceeded
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug("Batch rate limit check passed, counters incremented")
            return data

        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data
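
For reviewers, a sketch of how this hook is exercised end to end (it mirrors the tests further down; assumes an in-memory DualCache and a placeholder input_file_id, so the file read itself would still require a real uploaded file):

# Illustrative sketch only - not part of this commit.
from fastapi import HTTPException

from litellm import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
    _PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.utils import InternalUsageCache

dual_cache = DualCache()
limiter_v3 = _PROXY_MaxParallelRequestsHandler_v3(
    internal_usage_cache=InternalUsageCache(dual_cache=dual_cache)
)
batch_limiter = limiter_v3._get_batch_rate_limiter()


async def submit_batch(data: dict) -> dict:
    key = UserAPIKeyAuth(api_key="sk-test", tpm_limit=200, rpm_limit=10)
    try:
        # Counts tokens/requests from data["input_file_id"] and reserves capacity.
        return await batch_limiter.async_pre_call_hook(
            user_api_key_dict=key,
            cache=dual_cache,
            data=data,
            call_type="acreate_batch",
        )
    except HTTPException as e:
        # Raised when the input file alone exceeds the remaining TPM/RPM window.
        print(e.status_code, e.headers.get("retry-after"), e.detail)
        raise


# usage (placeholder file id):
# await submit_batch({"model": "gpt-3.5-turbo", "input_file_id": "file-abc123", "custom_llm_provider": "openai"})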

litellm/proxy/hooks/parallel_request_limiter_v3.py

@@ -162,6 +162,27 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
        self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
        # Batch rate limiter (lazy loaded)
        self._batch_rate_limiter: Optional[Any] = None

    def _get_batch_rate_limiter(self) -> Optional[Any]:
        """Get or lazy-load the batch rate limiter."""
        if self._batch_rate_limiter is None:
            try:
                from litellm.proxy.hooks.batch_rate_limiter import (
                    _PROXY_BatchRateLimiter,
                )

                self._batch_rate_limiter = _PROXY_BatchRateLimiter(
                    internal_usage_cache=self.internal_usage_cache,
                    parallel_request_limiter=self,
                )
            except Exception as e:
                verbose_proxy_logger.debug(
                    f"Could not load batch rate limiter: {str(e)}"
                )
        return self._batch_rate_limiter

    def _get_current_time(self) -> datetime:
        """Return the current time for rate limiting calculations."""
        return self._time_provider()
@@ -987,6 +1008,13 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
            # Fail safe: enforce limits if we can't check
            return True

    def get_rate_limiter_for_call_type(self, call_type: str) -> Optional[Any]:
        """Get the rate limiter for the call type."""
        if call_type == "acreate_batch":
            batch_limiter = self._get_batch_rate_limiter()
            return batch_limiter
        return None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
@@ -1000,6 +1028,19 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
        """
        verbose_proxy_logger.debug("Inside Rate Limit Pre-Call Hook")

        #########################################################
        # Check if the call type has a specific rate limiter
        # eg. for Batch APIs we need to use the batch rate limiter to read the input file and count the tokens and requests
        #########################################################
        call_type_specific_rate_limiter = self.get_rate_limiter_for_call_type(call_type=call_type)
        if call_type_specific_rate_limiter:
            return await call_type_specific_rate_limiter.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=cache,
                data=data,
                call_type=call_type,
            )

        # Get rate limit types from metadata
        metadata = user_api_key_dict.metadata or {}
        rpm_limit_type = metadata.get("rpm_limit_type")
@@ -1470,6 +1511,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
                f"Error in rate limit failure event: {str(e)}"
            )

    async def async_post_call_success_hook(
        self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
    ):

@@ -22,5 +22,13 @@ search_tools:
    search_provider: exa_ai
    api_key: os.environ/EXA_API_KEY

# for /files endpoints
files_settings:
  - custom_llm_provider: openai
    api_key: os.environ/OPENAI_API_KEY

litellm_settings:
  callbacks: ["datadog"]
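
The limits the batch hook enforces come from the key itself (tpm_limit / rpm_limit), not from this config; files_settings is only needed so the proxy can read batch input files. A hedged sketch of issuing a limited key against a locally running proxy (the URL and master key are assumptions):

# Illustrative sketch only - not part of this commit.
import httpx

resp = httpx.post(
    "http://0.0.0.0:4000/key/generate",  # assumed local proxy
    headers={"Authorization": "Bearer sk-1234"},  # assumed master key
    json={"tpm_limit": 200, "rpm_limit": 10},
)
print(resp.json()["key"])  # use this key for POST /v1/batches; oversized input files get a 429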


@@ -0,0 +1,14 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello, how are you?"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "What is the weather today?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Tell me a short joke"}]}}


@@ -0,0 +1,391 @@
"""
Integration Tests for Batch Rate Limits
"""
import asyncio
import json
import os
import sys
import pytest
from fastapi import HTTPException
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.batch_rate_limiter import (
BatchFileUsage,
_PROXY_BatchRateLimiter,
)
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
_PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.utils import InternalUsageCache
def get_expected_batch_file_usage(file_path: str) -> tuple[int, int]:
"""
Helper function to calculate expected request count and token count from a batch JSONL file.
Returns:
tuple[int, int]: (expected_request_count, expected_total_tokens)
"""
with open(file_path, 'r') as f:
file_contents = [json.loads(line) for line in f if line.strip()]
expected_request_count = len(file_contents)
expected_total_tokens = 0
for item in file_contents:
body = item.get("body", {})
model = body.get("model", "")
messages = body.get("messages", [])
if messages:
item_tokens = litellm.token_counter(model=model, messages=messages)
expected_total_tokens += item_tokens
return expected_request_count, expected_total_tokens


@pytest.mark.asyncio()
@pytest.mark.skipif(
    os.environ.get("OPENAI_API_KEY") is None,
    reason="OPENAI_API_KEY not set - skipping integration test"
)
async def test_batch_rate_limits():
    """
    Integration test for batch rate limits with real OpenAI API calls.
    Tests the full flow: file creation -> token counting -> cleanup
    """
    litellm._turn_on_debug()
    CUSTOM_LLM_PROVIDER = "openai"
    BATCH_LIMITER = _PROXY_BatchRateLimiter(
        internal_usage_cache=None,
        parallel_request_limiter=None,
    )

    file_name = "openai_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)

    # Create file on OpenAI
    print(f"Creating file from {file_path}")
    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider=CUSTOM_LLM_PROVIDER,
    )
    print(f"Response from creating file: {file_obj}")
    assert file_obj.id is not None, "File ID should not be None"

    # Give API a moment to process the file
    await asyncio.sleep(1)

    # Count requests and token usage in input file
    tracked_batch_file_usage: BatchFileUsage = await BATCH_LIMITER.count_input_file_usage(
        file_id=file_obj.id,
        custom_llm_provider=CUSTOM_LLM_PROVIDER,
    )
    print(f"Actual total tokens: {tracked_batch_file_usage.total_tokens}")
    print(f"Actual request count: {tracked_batch_file_usage.request_count}")

    # Calculate expected values by reading the JSONL file
    expected_request_count, expected_total_tokens = get_expected_batch_file_usage(file_path=file_path)
    print(f"Expected request count: {expected_request_count}")
    print(f"Expected total tokens: {expected_total_tokens}")

    # Verify token counting results
    assert tracked_batch_file_usage.request_count == expected_request_count, f"Expected {expected_request_count} requests, got {tracked_batch_file_usage.request_count}"
    assert tracked_batch_file_usage.total_tokens == expected_total_tokens, f"Expected {expected_total_tokens} total_tokens, got {tracked_batch_file_usage.total_tokens}"


@pytest.mark.asyncio()
async def test_batch_rate_limit_single_file():
    """
    Test batch rate limiting with a single file.

    Key has TPM = 200
    - File with < 200 tokens: should go through
    - File with > 200 tokens: should hit rate limit
    """
    import tempfile

    CUSTOM_LLM_PROVIDER = "openai"

    # Setup: Create internal usage cache and rate limiter
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )

    # Setup: Get batch rate limiter
    batch_limiter = rate_limiter._get_batch_rate_limiter()
    assert batch_limiter is not None, "Batch rate limiter should be available"

    # Setup: Create user API key with TPM = 200
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test-key-123",
        tpm_limit=200,
        rpm_limit=10,
    )

    # Test 1: File with < 200 tokens should go through
    print("\n=== Test 1: File under 200 tokens ===")

    # Create a small batch file with ~150 tokens
    small_batch_content = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey"}]}}"""

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(small_batch_content)
        small_file_path = f.name

    try:
        # Upload file to OpenAI
        file_obj_small = await litellm.acreate_file(
            file=open(small_file_path, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created small file: {file_obj_small.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_under_limit = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_small.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should not raise an exception
        result = await batch_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=dual_cache,
            data=data_under_limit,
            call_type="acreate_batch",
        )
        print(f"✓ File with ~150 tokens passed (under limit of 200)")
        print(f"  Actual tokens: {result.get('_batch_token_count')}")
    except HTTPException as e:
        pytest.fail(f"Should not have hit rate limit with small file: {e.detail}")
    finally:
        os.unlink(small_file_path)

    # Test 2: File with > 200 tokens should hit rate limit
    print("\n=== Test 2: File over 200 tokens ===")

    # Reset cache for clean test
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )
    batch_limiter = rate_limiter._get_batch_rate_limiter()

    # Create a larger batch file with ~10000+ tokens (100x larger to ensure it exceeds 200 token limit)
    base_message = "This is a longer message that will consume more tokens from the rate limit. " * 100

    # Build JSONL content with json.dumps to avoid f-string nesting issues
    import json as json_lib
    requests = []
    for i in range(1, 4):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": base_message}]
            }
        }
        requests.append(json_lib.dumps(request_obj))
    large_batch_content = "\n".join(requests)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(large_batch_content)
        large_file_path = f.name

    try:
        # Upload file to OpenAI
        file_obj_large = await litellm.acreate_file(
            file=open(large_file_path, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created large file: {file_obj_large.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_over_limit = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_large.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should raise HTTPException with 429 status
        with pytest.raises(HTTPException) as exc_info:
            await batch_limiter.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=dual_cache,
                data=data_over_limit,
                call_type="acreate_batch",
            )

        assert exc_info.value.status_code == 429, "Should return 429 status code"
        assert "tokens" in exc_info.value.detail.lower(), "Error message should mention tokens"
        print("✓ Large file correctly rejected (over the 200 token limit)")
        print(f"  Error: {exc_info.value.detail}")
    finally:
        os.unlink(large_file_path)


@pytest.mark.asyncio()
async def test_batch_rate_limit_multiple_requests():
    """
    Test batch rate limiting with multiple requests.

    Key has TPM = 200
    - Request 1: file with ~100 tokens (should go through, 100/200 used)
    - Request 2: file with ~105 tokens (should hit limit, 100+105=205 > 200)
    """
    import tempfile

    CUSTOM_LLM_PROVIDER = "openai"

    # Setup: Create internal usage cache and rate limiter
    dual_cache = DualCache()
    internal_usage_cache = InternalUsageCache(dual_cache=dual_cache)
    rate_limiter = _PROXY_MaxParallelRequestsHandler_v3(
        internal_usage_cache=internal_usage_cache
    )

    # Setup: Get batch rate limiter
    batch_limiter = rate_limiter._get_batch_rate_limiter()
    assert batch_limiter is not None, "Batch rate limiter should be available"

    # Setup: Create user API key with TPM = 200
    user_api_key_dict = UserAPIKeyAuth(
        api_key="test-key-456",
        tpm_limit=200,
        rpm_limit=10,
    )

    # Request 1: File with ~100 tokens
    print("\n=== Request 1: File with ~100 tokens ===")

    # Create file with ~100 tokens
    import json as json_lib
    message_1 = "This message has some content to reach about 100 tokens total. " * 4
    requests_1 = []
    for i in range(1, 3):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": message_1}]
            }
        }
        requests_1.append(json_lib.dumps(request_obj))
    batch_content_1 = "\n".join(requests_1)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(batch_content_1)
        file_path_1 = f.name

    try:
        # Upload file to OpenAI
        file_obj_1 = await litellm.acreate_file(
            file=open(file_path_1, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created file 1: {file_obj_1.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_request1 = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_1.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should not raise an exception
        result1 = await batch_limiter.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=dual_cache,
            data=data_request1,
            call_type="acreate_batch",
        )
        tokens_used_1 = result1.get('_batch_token_count', 0)
        print(f"✓ Request 1 with {tokens_used_1} tokens passed ({tokens_used_1}/200 used)")
    except HTTPException as e:
        pytest.fail(f"Request 1 should not have hit rate limit: {e.detail}")
    finally:
        os.unlink(file_path_1)

    # Request 2: File with ~105+ tokens (total would exceed 200)
    print("\n=== Request 2: File with ~105 tokens (should hit limit) ===")

    # Create file with ~105+ tokens
    message_2 = "This is another message with more content to exceed the remaining limit. " * 11
    requests_2 = []
    for i in range(1, 3):
        request_obj = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": message_2}]
            }
        }
        requests_2.append(json_lib.dumps(request_obj))
    batch_content_2 = "\n".join(requests_2)

    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        f.write(batch_content_2)
        file_path_2 = f.name

    try:
        # Upload file to OpenAI
        file_obj_2 = await litellm.acreate_file(
            file=open(file_path_2, "rb"),
            purpose="batch",
            custom_llm_provider=CUSTOM_LLM_PROVIDER,
        )
        print(f"Created file 2: {file_obj_2.id}")
        await asyncio.sleep(1)  # Give API time to process

        data_request2 = {
            "model": "gpt-3.5-turbo",
            "input_file_id": file_obj_2.id,
            "custom_llm_provider": CUSTOM_LLM_PROVIDER,
        }

        # Should raise HTTPException with 429 status
        with pytest.raises(HTTPException) as exc_info:
            await batch_limiter.async_pre_call_hook(
                user_api_key_dict=user_api_key_dict,
                cache=dual_cache,
                data=data_request2,
                call_type="acreate_batch",
            )

        assert exc_info.value.status_code == 429, "Should return 429 status code"
        assert "tokens" in exc_info.value.detail.lower(), "Error message should mention tokens"
        print("✓ Request 2 correctly rejected")
        print(f"  Error: {exc_info.value.detail}")
    finally:
        os.unlink(file_path_2)


@@ -0,0 +1,4 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}]}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "What is 2+2?"}]}}
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4", "messages": [{"role": "user", "content": "Tell me a joke about programming"}]}}