litellm/litellm/batches/batch_utils.py
Ishaan Jaff aea78b8d1a [Feat] Add support for Batch API Rate limiting - PR1 adds support for input based rate limits (#16075)
* add count_input_file_usage

* fix count_input_file_usage

* _get_batch_job_input_file_usage

* fixes imports

* use _get_batch_job_input_file_usage

* test_batch_rate_limits

* add _check_and_increment_batch_counters

* add get_rate_limiter_for_call_type

* test_batch_rate_limit_multiple_requests

* fixes for batch limits

* fix linting

* fix MYPY linting
2025-10-29 18:28:52 -07:00


import json
import time
from typing import Any, List, Literal, Optional, Tuple

import httpx

import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.types.llms.openai import Batch
from litellm.types.utils import CallTypes, ModelResponse, Usage
from litellm.utils import token_counter


async def calculate_batch_cost_and_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage, List[str]]:
    """
    Calculate the cost, token usage, and models used for a batch from its output file content
    """
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)
batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
return batch_cost, batch_usage, batch_models
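

# Illustrative usage sketch: the output-file line shape below is inferred from
# the helpers in this module, and the values are made up.
#
#   >>> import asyncio
#   >>> lines = [
#   ...     {
#   ...         "response": {
#   ...             "status_code": 200,
#   ...             "body": {
#   ...                 "model": "gpt-4o-mini",
#   ...                 "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
#   ...             },
#   ...         }
#   ...     }
#   ... ]
#   >>> cost, usage, models = asyncio.run(
#   ...     calculate_batch_cost_and_usage(lines, custom_llm_provider="openai")
#   ... )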


async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage, List[str]]:
    """Helper to retrieve a completed batch's output file and compute its cost, token usage, and models"""
# Get batch results
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
batch, custom_llm_provider
)
# Calculate costs and usage
batch_cost = _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
model_name=model_name,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
model_name=model_name,
)
batch_models = _get_batch_models_from_file_content(file_content_dictionary, model_name)
return batch_cost, batch_usage, batch_models


def _get_batch_models_from_file_content(
    file_content_dictionary: List[dict],
    model_name: Optional[str] = None,
) -> List[str]:
    """
    Get the models used in the batch from the output file content
    """
if model_name:
return [model_name]
    batch_models: List[str] = []
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
_model = _response_body.get("model")
if _model:
batch_models.append(_model)
return batch_models
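

# Example (minimal sketch): when model_name is provided it short-circuits the
# file scan.
#
#   >>> _get_batch_models_from_file_content([], model_name="gpt-4o-mini")
#   ['gpt-4o-mini']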


def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> float:
    """
    Calculate the cost of a batch from its output file content
    """
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
return batch_cost
# For other providers, use the existing logic
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost


def calculate_vertex_ai_batch_cost_and_usage(
vertex_ai_batch_responses: List[dict],
model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
"""
Calculate both cost and usage from Vertex AI batch responses
"""
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
VertexGeminiConfig,
)
total_cost = 0.0
total_tokens = 0
prompt_tokens = 0
completion_tokens = 0
for response in vertex_ai_batch_responses:
if response.get("status") == "JOB_STATE_SUCCEEDED": # Check if response was successful
# Transform Vertex AI response to OpenAI format if needed
# Create required arguments for the transformation method
model_response = ModelResponse()
# Ensure model_name is not None
actual_model_name = model_name or "gemini-2.5-flash"
# Create a real LiteLLM logging object
logging_obj = Logging(
model=actual_model_name,
messages=[{"role": "user", "content": "batch_request"}],
stream=False,
call_type=CallTypes.aretrieve_batch,
start_time=time.time(),
litellm_call_id="batch_" + str(uuid.uuid4()),
function_id="batch_processing",
litellm_trace_id=str(uuid.uuid4()),
kwargs={"optional_params": {}}
)
# Add the optional_params attribute that the Vertex AI transformation expects
logging_obj.optional_params = {}
raw_response = httpx.Response(200) # Mock response object
openai_format_response = VertexGeminiConfig()._transform_google_generate_content_to_openai_model_response(
completion_response=response["response"],
model_response=model_response,
model=actual_model_name,
logging_obj=logging_obj,
raw_response=raw_response,
)
# Calculate cost using existing function
cost = litellm.completion_cost(
completion_response=openai_format_response,
custom_llm_provider="vertex_ai",
call_type=CallTypes.aretrieve_batch.value,
)
total_cost += cost
# Extract usage from the transformed response
usage_obj = getattr(openai_format_response, 'usage', None)
if usage_obj:
usage = usage_obj
else:
# Fallback: create usage from response dict
response_dict = openai_format_response.dict() if hasattr(openai_format_response, 'dict') else {}
usage = _get_batch_job_usage_from_response_body(response_dict)
total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens
return total_cost, Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
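

# Shape of the Vertex AI batch response lines consumed above (a minimal,
# made-up sketch; field names follow the generateContent response format):
#
#   {
#       "status": "JOB_STATE_SUCCEEDED",
#       "response": {
#           "candidates": [{"content": {"parts": [{"text": "..."}], "role": "model"}}],
#           "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 5, "totalTokenCount": 15},
#       },
#   }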


async def _get_batch_output_file_content_as_dictionary(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> List[dict]:
"""
Get the batch output file content as a list of dictionaries
"""
from litellm.files.main import afile_content
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
if batch.output_file_id is None:
raise ValueError("Output file id is None cannot retrieve file content")
_file_content = await afile_content(
file_id=batch.output_file_id,
custom_llm_provider=custom_llm_provider,
)
return _get_file_content_as_dictionary(_file_content.content)


def _get_file_content_as_dictionary(
"""
Get the file content as a list of dictionaries from JSON Lines format
"""
    _file_content_str = file_content.decode("utf-8")
    # Split by newlines and parse each line as a separate JSON object
    json_objects = []
    for line in _file_content_str.strip().split("\n"):
        if line:  # Skip empty lines
            json_objects.append(json.loads(line))
    verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
    return json_objects
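

# Example (minimal sketch), parsing two JSON Lines records:
#
#   >>> _get_file_content_as_dictionary(b'{"a": 1}\n{"b": 2}\n')
#   [{'a': 1}, {'b': 2}]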


def _get_batch_job_cost_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
"""
Get the cost of a batch job from the file content
"""
try:
total_cost: float = 0.0
# parse the file content as json
verbose_logger.debug(
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
)
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
call_type=CallTypes.aretrieve_batch.value,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
    except Exception as e:
        verbose_logger.error("error in _get_batch_job_cost_from_file_content: %s", e)
        raise e


def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Get the total token usage of a batch job from the output file content
    """
# Handle Vertex AI with specialized method
if custom_llm_provider == "vertex_ai" and model_name:
_, batch_usage = calculate_vertex_ai_batch_cost_and_usage(file_content_dictionary, model_name)
return batch_usage
# For other providers, use the existing logic
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens
return Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
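

# Example (minimal sketch): usage is summed across successful lines only.
#
#   >>> lines = [
#   ...     {"response": {"status_code": 200, "body": {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}}},
#   ...     {"response": {"status_code": 500, "body": {}}},
#   ... ]
#   >>> _get_batch_job_total_usage_from_file_content(lines).total_tokens
#   15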


def _get_batch_job_input_file_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Count the number of prompt tokens in a batch input file.

    Used for batch rate limiting: the input file's token count is checked
    against rate limits before the batch is processed.
    """
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
body = _item.get("body", {})
model = body.get("model", model_name or "")
messages = body.get("messages", [])
if messages:
item_tokens = token_counter(model=model, messages=messages)
prompt_tokens += item_tokens
return Usage(
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
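

# Example (minimal sketch): each input-file line's "body" holds the request
# that would be sent; only its messages are counted, via token_counter.
#
#   >>> lines = [
#   ...     {"body": {"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "hi"}]}}
#   ... ]
#   >>> usage = _get_batch_job_input_file_usage(lines)  # prompt_tokens depends on the model's tokenizer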


def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """
    Get the token usage of a batch job from the response body
    """
_usage_dict = response_body.get("usage", None) or {}
usage: Usage = Usage(**_usage_dict)
return usage
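

# Example (minimal sketch):
#
#   >>> _get_batch_job_usage_from_response_body(
#   ...     {"usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}}
#   ... ).total_tokens
#   3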


def _get_response_from_batch_job_output_file(
"""
Get the response from the batch job output file
"""
_response: dict = batch_job_output_file.get("response", None) or {}
_response_body = _response.get("body", None) or {}
return _response_body


def _batch_response_was_successful(
"""
Check if the batch job response status == 200
"""
_response: dict = batch_job_output_file.get("response", None) or {}
return _response.get("status_code", None) == 200
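

# Example (minimal sketch): a successful output-file line has a 200 status_code
# under "response"; its "body" is the provider response.
#
#   >>> line = {"response": {"status_code": 200, "body": {"model": "gpt-4o-mini"}}}
#   >>> _batch_response_was_successful(line)
#   True
#   >>> _get_response_from_batch_job_output_file(line)
#   {'model': 'gpt-4o-mini'}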