Merge pull request #17498 from BerriAI/litellm_customer_usage_backend

[Feature] Customer (end user) Usage
This commit is contained in:
yuneng-jiang
2025-12-05 15:31:08 -08:00
committed by GitHub
16 changed files with 536 additions and 10 deletions

Binary file not shown.

View File

@@ -0,0 +1,42 @@
-- CreateTable
-- Daily rollup of end-user (customer) spend. One row per
-- (end_user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name);
-- token/request counters and spend are accumulated into the matching row.
CREATE TABLE "LiteLLM_DailyEndUserSpend" (
"id" TEXT NOT NULL,
"end_user_id" TEXT,
"date" TEXT NOT NULL,
"api_key" TEXT NOT NULL,
"model" TEXT,
"model_group" TEXT,
"custom_llm_provider" TEXT,
"mcp_namespaced_tool_name" TEXT,
"prompt_tokens" BIGINT NOT NULL DEFAULT 0,
"completion_tokens" BIGINT NOT NULL DEFAULT 0,
"cache_read_input_tokens" BIGINT NOT NULL DEFAULT 0,
"cache_creation_input_tokens" BIGINT NOT NULL DEFAULT 0,
"spend" DOUBLE PRECISION NOT NULL DEFAULT 0.0,
"api_requests" BIGINT NOT NULL DEFAULT 0,
"successful_requests" BIGINT NOT NULL DEFAULT 0,
"failed_requests" BIGINT NOT NULL DEFAULT 0,
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
"updated_at" TIMESTAMP(3) NOT NULL,
CONSTRAINT "LiteLLM_DailyEndUserSpend_pkey" PRIMARY KEY ("id")
);
-- CreateIndex
CREATE INDEX "LiteLLM_DailyEndUserSpend_date_idx" ON "LiteLLM_DailyEndUserSpend"("date");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyEndUserSpend_end_user_id_idx" ON "LiteLLM_DailyEndUserSpend"("end_user_id");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyEndUserSpend_api_key_idx" ON "LiteLLM_DailyEndUserSpend"("api_key");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyEndUserSpend_model_idx" ON "LiteLLM_DailyEndUserSpend"("model");
-- CreateIndex
CREATE INDEX "LiteLLM_DailyEndUserSpend_mcp_namespaced_tool_name_idx" ON "LiteLLM_DailyEndUserSpend"("mcp_namespaced_tool_name");
-- CreateIndex
-- Backs the upsert performed by the daily-spend batch writer: one unique row
-- per end user / day / key / model / provider / MCP tool combination.
CREATE UNIQUE INDEX "LiteLLM_DailyEndUserSpend_end_user_id_date_api_key_model_cu_key" ON "LiteLLM_DailyEndUserSpend"("end_user_id", "date", "api_key", "model", "custom_llm_provider", "mcp_namespaced_tool_name");

View File

@@ -465,6 +465,34 @@ model LiteLLM_DailyOrganizationSpend {
@@index([mcp_namespaced_tool_name])
}
// Track daily end user (customer) spend metrics per model and key
model LiteLLM_DailyEndUserSpend {
  // Surrogate primary key for the daily rollup row.
  id String @id @default(uuid())
  // End user (customer) id; nullable in the schema.
  end_user_id String?
  // Calendar date bucket, stored as a string (matches the other Daily*Spend models).
  date String
  api_key String
  model String?
  model_group String?
  custom_llm_provider String?
  mcp_namespaced_tool_name String?
  // Accumulated token counters for the day.
  prompt_tokens BigInt @default(0)
  completion_tokens BigInt @default(0)
  cache_read_input_tokens BigInt @default(0)
  cache_creation_input_tokens BigInt @default(0)
  // Accumulated spend in USD for the day.
  spend Float @default(0.0)
  // Request counters: total, successful, failed.
  api_requests BigInt @default(0)
  successful_requests BigInt @default(0)
  failed_requests BigInt @default(0)
  created_at DateTime @default(now())
  updated_at DateTime @updatedAt
  // Upsert target for the daily-spend batch writer.
  @@unique([end_user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
  @@index([date])
  @@index([end_user_id])
  @@index([api_key])
  @@index([model])
  @@index([mcp_namespaced_tool_name])
}
// Track daily team spend metrics per model and key
model LiteLLM_DailyTeamSpend {
id String @id @default(uuid())

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "litellm-proxy-extras"
version = "0.4.9"
version = "0.4.10"
description = "Additional files for the LiteLLM Proxy. Reduces the size of the main litellm package."
authors = ["BerriAI"]
readme = "README.md"
@@ -22,7 +22,7 @@ requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.commitizen]
version = "0.4.9"
version = "0.4.10"
version_files = [
"pyproject.toml:version",
"../requirements.txt:litellm-proxy-extras==",

View File

@@ -149,6 +149,7 @@ REDIS_UPDATE_BUFFER_KEY = "litellm_spend_update_buffer"
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_spend_update_buffer"
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_team_spend_update_buffer"
REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_org_spend_update_buffer"
REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_end_user_spend_update_buffer"
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY = "litellm_daily_tag_spend_update_buffer"
MAX_REDIS_BUFFER_DEQUEUE_COUNT = int(os.getenv("MAX_REDIS_BUFFER_DEQUEUE_COUNT", 100))
MAX_SIZE_IN_MEMORY_QUEUE = int(os.getenv("MAX_SIZE_IN_MEMORY_QUEUE", 10000))

View File

@@ -3639,6 +3639,8 @@ class DailyOrganizationSpendTransaction(BaseDailySpendTransaction):
class DailyUserSpendTransaction(BaseDailySpendTransaction):
    # Internal user the daily spend row is attributed to.
    user_id: str
class DailyEndUserSpendTransaction(BaseDailySpendTransaction):
    # End user (customer) the daily spend row is attributed to.
    end_user_id: str
class DailyTagSpendTransaction(BaseDailySpendTransaction):
    # Originating spend-log request id; None when not available.
    request_id: Optional[str]

View File

@@ -25,6 +25,7 @@ from litellm.proxy._types import (
DailyTagSpendTransaction,
DailyOrganizationSpendTransaction,
DailyTeamSpendTransaction,
DailyEndUserSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
Litellm_EntityType,
@@ -65,6 +66,7 @@ class DBSpendUpdateWriter:
self.spend_update_queue = SpendUpdateQueue()
self.daily_spend_update_queue = DailySpendUpdateQueue()
self.daily_team_spend_update_queue = DailySpendUpdateQueue()
self.daily_end_user_spend_update_queue = DailySpendUpdateQueue()
self.daily_org_spend_update_queue = DailySpendUpdateQueue()
self.daily_tag_spend_update_queue = DailySpendUpdateQueue()
@@ -182,6 +184,13 @@ class DBSpendUpdateWriter:
)
)
asyncio.create_task(
self.add_spend_log_transaction_to_daily_end_user_transaction(
payload=payload,
prisma_client=prisma_client,
)
)
asyncio.create_task(
self.add_spend_log_transaction_to_daily_team_transaction(
payload=payload,
@@ -475,6 +484,7 @@ class DBSpendUpdateWriter:
daily_spend_update_queue=self.daily_spend_update_queue,
daily_team_spend_update_queue=self.daily_team_spend_update_queue,
daily_org_spend_update_queue=self.daily_org_spend_update_queue,
daily_end_user_spend_update_queue=self.daily_end_user_spend_update_queue,
daily_tag_spend_update_queue=self.daily_tag_spend_update_queue,
)
@@ -538,6 +548,16 @@ class DBSpendUpdateWriter:
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_tag_spend_update_transactions,
)
daily_end_user_spend_update_transactions = (
await self.redis_update_buffer.get_all_daily_end_user_spend_update_transactions_from_redis_buffer()
)
if daily_end_user_spend_update_transactions is not None:
await DBSpendUpdateWriter.update_daily_end_user_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_end_user_spend_update_transactions,
)
except Exception as e:
verbose_proxy_logger.error(f"Error committing spend updates: {e}")
finally:
@@ -627,6 +647,20 @@ class DBSpendUpdateWriter:
daily_spend_transactions=daily_tag_spend_update_transactions,
)
################## Daily End-User Spend Update Transactions ##################
# Aggregate all in memory daily end-user spend transactions and commit to db
daily_end_user_spend_update_transactions = cast(
Dict[str, DailyEndUserSpendTransaction],
await self.daily_end_user_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions(),
)
await DBSpendUpdateWriter.update_daily_end_user_spend(
n_retry_times=n_retry_times,
prisma_client=prisma_client,
proxy_logging_obj=proxy_logging_obj,
daily_spend_transactions=daily_end_user_spend_update_transactions,
)
async def _commit_spend_updates_to_db( # noqa: PLR0915
self,
prisma_client: PrismaClient,
@@ -990,6 +1024,20 @@ class DBSpendUpdateWriter:
) -> None:
...
@overload
@staticmethod
async def _update_daily_spend(
n_retry_times: int,
prisma_client: PrismaClient,
proxy_logging_obj: ProxyLogging,
daily_spend_transactions: Dict[str, DailyEndUserSpendTransaction],
entity_type: Literal["end_user"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
...
@overload
@staticmethod
async def _update_daily_spend(
@@ -1015,14 +1063,15 @@ class DBSpendUpdateWriter:
Dict[str, DailyTeamSpendTransaction],
Dict[str, DailyTagSpendTransaction],
Dict[str, DailyOrganizationSpendTransaction],
Dict[str, DailyEndUserSpendTransaction],
],
entity_type: Literal["user", "team", "org", "tag"],
entity_type: Literal["user", "team", "org", "tag", "end_user"],
entity_id_field: str,
table_name: str,
unique_constraint_name: str,
) -> None:
"""
Generic function to update daily spend for any entity type (user, team, org, tag)
Generic function to update daily spend for any entity type (user, team, org, tag, end_user)
"""
from litellm.proxy.utils import _raise_failed_update_spend_exception
@@ -1267,6 +1316,27 @@ class DBSpendUpdateWriter:
unique_constraint_name="organization_id_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name",
)
@staticmethod
async def update_daily_end_user_spend(
    n_retry_times: int,
    prisma_client: PrismaClient,
    proxy_logging_obj: ProxyLogging,
    daily_spend_transactions: Dict[str, DailyEndUserSpendTransaction],
):
    """
    Batch job to update LiteLLM_DailyEndUserSpend table using in-memory daily_spend_transactions

    Args:
        n_retry_times: retry budget forwarded to the generic writer.
        prisma_client: active Prisma client used for the DB upserts.
        proxy_logging_obj: logging handle used by the generic writer on failure.
        daily_spend_transactions: aggregated end-user transactions, keyed by the
            in-memory daily transaction key.
    """
    # Thin delegation to the shared per-entity writer. The table_name and
    # unique_constraint_name strings must stay in sync with the
    # LiteLLM_DailyEndUserSpend Prisma model and its @@unique constraint.
    await DBSpendUpdateWriter._update_daily_spend(
        n_retry_times=n_retry_times,
        prisma_client=prisma_client,
        proxy_logging_obj=proxy_logging_obj,
        daily_spend_transactions=daily_spend_transactions,
        entity_type="end_user",
        entity_id_field="end_user_id",
        table_name="litellm_dailyenduserspend",
        unique_constraint_name="end_user_id_date_api_key_model_custom_llm_provider_mcp_namespaced_tool_name",
    )
@staticmethod
async def update_daily_tag_spend(
n_retry_times: int,
@@ -1292,7 +1362,7 @@ class DBSpendUpdateWriter:
self,
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
type: Literal["user", "team", "org", "request_tags"] = "user",
type: Literal["user", "team", "org", "request_tags", "end_user"] = "user",
) -> Optional[BaseDailySpendTransaction]:
common_expected_keys = ["startTime", "api_key"]
if type == "user":
@@ -1303,6 +1373,8 @@ class DBSpendUpdateWriter:
expected_keys = ["organization_id", *common_expected_keys]
elif type == "request_tags":
expected_keys = ["request_tags", *common_expected_keys]
elif type == "end_user":
expected_keys = ["end_user_id", *common_expected_keys]
else:
raise ValueError(f"Invalid type: {type}")
if not all(key in payload for key in expected_keys):
@@ -1474,6 +1546,48 @@ class DBSpendUpdateWriter:
update={daily_transaction_key: daily_transaction}
)
async def add_spend_log_transaction_to_daily_end_user_transaction(
    self,
    payload: SpendLogsPayload,
    prisma_client: Optional[PrismaClient] = None,
) -> None:
    """
    Queue a daily end-user (customer) spend aggregate for one spend log.

    Silently skips (debug log only) when no DB client is configured or the
    payload carries no "end_user".
    """
    if prisma_client is None:
        verbose_proxy_logger.debug(
            "prisma_client is None. Skipping writing spend logs to db."
        )
        return
    end_user_id = payload.get("end_user")
    if end_user_id is None or end_user_id == "":
        verbose_proxy_logger.debug(
            "end_user is None or empty for request. Skipping incrementing end user spend."
        )
        return
    # The shared helper validates an "end_user_id" key, but the spend-log
    # payload carries the value as "end_user" — inject the expected key.
    payload_with_end_user_id = cast(
        SpendLogsPayload,
        {
            **payload,
            "end_user_id": end_user_id,
        },
    )
    base_daily_transaction = (
        await self._common_add_spend_log_transaction_to_daily_transaction(
            payload_with_end_user_id, prisma_client, "end_user"
        )
    )
    if base_daily_transaction is None:
        return
    # NOTE(review): the in-memory aggregation key omits mcp_namespaced_tool_name
    # even though the table's unique constraint includes it — confirm this
    # matches the keying used by the other daily-spend writers.
    daily_transaction_key = f"{end_user_id}_{base_daily_transaction['date']}_{payload_with_end_user_id['api_key']}_{payload_with_end_user_id['model']}_{payload_with_end_user_id['custom_llm_provider']}"
    daily_transaction = DailyEndUserSpendTransaction(
        end_user_id=end_user_id, **base_daily_transaction
    )
    await self.daily_end_user_spend_update_queue.add_update(
        update={daily_transaction_key: daily_transaction}
    )
async def add_spend_log_transaction_to_daily_tag_transaction(
self,
payload: SpendLogsPayload,

View File

@@ -16,6 +16,7 @@ from litellm.constants import (
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
REDIS_UPDATE_BUFFER_KEY,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
@@ -24,6 +25,7 @@ from litellm.proxy._types import (
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DailyOrganizationSpendTransaction,
DailyEndUserSpendTransaction,
DBSpendUpdateTransactions,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
@@ -107,6 +109,7 @@ class RedisUpdateBuffer:
daily_spend_update_queue: DailySpendUpdateQueue,
daily_team_spend_update_queue: DailySpendUpdateQueue,
daily_org_spend_update_queue: DailySpendUpdateQueue,
daily_end_user_spend_update_queue: DailySpendUpdateQueue,
daily_tag_spend_update_queue: DailySpendUpdateQueue,
):
"""
@@ -172,6 +175,9 @@ class RedisUpdateBuffer:
daily_org_spend_update_transactions = (
await daily_org_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_end_user_spend_update_transactions = (
await daily_end_user_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
daily_tag_spend_update_transactions = (
await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
@@ -207,6 +213,12 @@ class RedisUpdateBuffer:
service_type=ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
)
await self._store_transactions_in_redis(
transactions=daily_end_user_spend_update_transactions,
redis_key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
service_type=ServiceTypes.REDIS_DAILY_END_USER_SPEND_UPDATE_QUEUE,
)
await self._store_transactions_in_redis(
transactions=daily_tag_spend_update_transactions,
redis_key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
@@ -365,6 +377,30 @@ class RedisUpdateBuffer:
),
)
async def get_all_daily_end_user_spend_update_transactions_from_redis_buffer(
    self,
) -> Optional[Dict[str, DailyEndUserSpendTransaction]]:
    """
    Pop buffered daily end-user spend transactions off Redis and aggregate them.

    Returns None when Redis is not configured or the buffer is empty.
    """
    if self.redis_cache is None:
        return None
    raw_transactions = await self.redis_cache.async_lpop(
        key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_transactions is None:
        return None
    # Entries were stored as JSON strings; decode before aggregating.
    decoded_transactions = [json.loads(raw) for raw in raw_transactions]
    aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
        decoded_transactions
    )
    return cast(Dict[str, DailyEndUserSpendTransaction], aggregated)
async def get_all_daily_tag_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyTagSpendTransaction]]:

View File

@@ -20,6 +20,10 @@ from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
router = APIRouter()
@@ -673,4 +677,78 @@ async def list_end_user(
str(e)
)
)
raise handle_exception_on_proxy(e)
raise handle_exception_on_proxy(e)
@router.get(
    "/customer/daily/activity",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=SpendAnalyticsPaginatedResponse,
)
@router.get(
    "/end_user/daily/activity",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def get_customer_daily_activity(
    end_user_ids: Optional[str] = None,
    start_date: Optional[str] = None,
    end_date: Optional[str] = None,
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    exclude_end_user_ids: Optional[str] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Get daily spend activity for specific end users (customers), or all of them.

    Args:
        end_user_ids: comma-separated end-user ids to include; None = all.
        start_date / end_date: inclusive date range filter (string dates).
        model: optional model filter.
        api_key: optional api-key filter.
        page / page_size: pagination of the aggregated results.
        exclude_end_user_ids: comma-separated end-user ids to exclude.

    Raises:
        HTTPException(500) when the proxy has no database connection.
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
    )

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Parse comma-separated id filters.
    end_user_ids_list = end_user_ids.split(",") if end_user_ids else None
    exclude_end_user_ids_list: Optional[List[str]] = (
        exclude_end_user_ids.split(",") if exclude_end_user_ids else None
    )

    # Fetch end-user aliases so the response can carry human-readable names.
    where_condition: dict = {}
    if end_user_ids_list:
        where_condition["user_id"] = {"in": list(end_user_ids_list)}
    end_user_aliases = await prisma_client.db.litellm_endusertable.find_many(
        where=where_condition
    )
    end_user_alias_metadata = {
        e.user_id: {"alias": e.alias}
        for e in end_user_aliases
    }

    # Query daily activity from the end-user daily spend table.
    return await get_daily_activity(
        prisma_client=prisma_client,
        table_name="litellm_dailyenduserspend",
        entity_id_field="end_user_id",
        entity_id=end_user_ids_list,
        entity_metadata_field=end_user_alias_metadata,
        exclude_entity_ids=exclude_end_user_ids_list,
        start_date=start_date,
        end_date=end_date,
        model=model,
        api_key=api_key,
        page=page,
        page_size=page_size,
    )

View File

@@ -465,6 +465,34 @@ model LiteLLM_DailyOrganizationSpend {
@@index([mcp_namespaced_tool_name])
}
// Track daily end user (customer) spend metrics per model and key
model LiteLLM_DailyEndUserSpend {
  // Surrogate primary key for the daily rollup row.
  id String @id @default(uuid())
  // End user (customer) id; nullable in the schema.
  end_user_id String?
  // Calendar date bucket, stored as a string (matches the other Daily*Spend models).
  date String
  api_key String
  model String?
  model_group String?
  custom_llm_provider String?
  mcp_namespaced_tool_name String?
  // Accumulated token counters for the day.
  prompt_tokens BigInt @default(0)
  completion_tokens BigInt @default(0)
  cache_read_input_tokens BigInt @default(0)
  cache_creation_input_tokens BigInt @default(0)
  // Accumulated spend in USD for the day.
  spend Float @default(0.0)
  // Request counters: total, successful, failed.
  api_requests BigInt @default(0)
  successful_requests BigInt @default(0)
  failed_requests BigInt @default(0)
  created_at DateTime @default(now())
  updated_at DateTime @updatedAt
  // Upsert target for the daily-spend batch writer.
  @@unique([end_user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
  @@index([date])
  @@index([end_user_id])
  @@index([api_key])
  @@index([model])
  @@index([mcp_namespaced_tool_name])
}
// Track daily team spend metrics per model and key
model LiteLLM_DailyTeamSpend {
id String @id @default(uuid())

View File

@@ -59,7 +59,7 @@ websockets = {version = "^15.0.1", optional = true}
boto3 = {version = "1.36.0", optional = true}
redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.9' and python_version < '3.14'"}
mcp = {version = "^1.21.2", optional = true, python = ">=3.10"}
litellm-proxy-extras = {version = "0.4.9", optional = true}
litellm-proxy-extras = {version = "0.4.10", optional = true}
rich = {version = "13.7.1", optional = true}
litellm-enterprise = {version = "0.1.23", optional = true}
diskcache = {version = "^5.6.1", optional = true}

View File

@@ -44,7 +44,7 @@ sentry_sdk==2.21.0 # for sentry error handling
detect-secrets==1.5.0 # Enterprise - secret detection / masking in LLM requests
cryptography==44.0.1
tzdata==2025.1 # IANA time zone database
litellm-proxy-extras==0.4.9 # for proxy extras - e.g. prisma migrations
litellm-proxy-extras==0.4.10 # for proxy extras - e.g. prisma migrations
### LITELLM PACKAGE DEPENDENCIES
python-dotenv==1.0.1 # for env
tiktoken==0.8.0 # for calculating usage

View File

@@ -465,6 +465,34 @@ model LiteLLM_DailyOrganizationSpend {
@@index([mcp_namespaced_tool_name])
}
// Track daily end user (customer) spend metrics per model and key
model LiteLLM_DailyEndUserSpend {
  // Surrogate primary key for the daily rollup row.
  id String @id @default(uuid())
  // End user (customer) id; nullable in the schema.
  end_user_id String?
  // Calendar date bucket, stored as a string (matches the other Daily*Spend models).
  date String
  api_key String
  model String?
  model_group String?
  custom_llm_provider String?
  mcp_namespaced_tool_name String?
  // Accumulated token counters for the day.
  prompt_tokens BigInt @default(0)
  completion_tokens BigInt @default(0)
  cache_read_input_tokens BigInt @default(0)
  cache_creation_input_tokens BigInt @default(0)
  // Accumulated spend in USD for the day.
  spend Float @default(0.0)
  // Request counters: total, successful, failed.
  api_requests BigInt @default(0)
  successful_requests BigInt @default(0)
  failed_requests BigInt @default(0)
  created_at DateTime @default(now())
  updated_at DateTime @updatedAt
  // Upsert target for the daily-spend batch writer.
  @@unique([end_user_id, date, api_key, model, custom_llm_provider, mcp_namespaced_tool_name])
  @@index([date])
  @@index([end_user_id])
  @@index([api_key])
  @@index([model])
  @@index([mcp_namespaced_tool_name])
}
// Track daily team spend metrics per model and key
model LiteLLM_DailyTeamSpend {
id String @id @default(uuid())

View File

@@ -572,4 +572,77 @@ async def test_add_spend_log_transaction_to_daily_org_transaction_skips_when_org
org_id=None,
)
writer.daily_org_spend_update_queue.add_update.assert_not_called()
writer.daily_org_spend_update_queue.add_update.assert_not_called()
@pytest.mark.asyncio
async def test_add_spend_log_transaction_to_daily_end_user_transaction_injects_end_user_id_and_queues_update():
    """
    A spend-log payload carrying an "end_user" must produce exactly one queued
    DailyEndUserSpendTransaction, keyed by
    "{end_user_id}_{date}_{api_key}_{model}_{custom_llm_provider}".
    """
    writer = DBSpendUpdateWriter()
    mock_prisma = MagicMock()
    # "success" marks the request as non-failed when the daily transaction is built.
    mock_prisma.get_request_status = MagicMock(return_value="success")

    end_user_id = "end-user-xyz"
    payload = {
        "request_id": "req-1",
        "user": "test-user",
        "end_user": end_user_id,
        "startTime": "2024-01-01T12:00:00",
        "api_key": "test-key",
        "model": "gpt-4",
        "custom_llm_provider": "openai",
        "model_group": "gpt-4-group",
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "spend": 0.2,
        "metadata": '{"usage_object": {}}',
    }

    # Stub the queue so we can inspect exactly what gets enqueued.
    writer.daily_end_user_spend_update_queue.add_update = AsyncMock()

    await writer.add_spend_log_transaction_to_daily_end_user_transaction(
        payload=payload,
        prisma_client=mock_prisma,
    )

    writer.daily_end_user_spend_update_queue.add_update.assert_called_once()
    call_args = writer.daily_end_user_spend_update_queue.add_update.call_args[1]
    update_dict = call_args["update"]
    assert len(update_dict) == 1
    for key, transaction in update_dict.items():
        # Key uses the date portion of startTime; the transaction carries the
        # injected end_user_id plus the base daily-transaction fields.
        assert key == f"{end_user_id}_2024-01-01_test-key_gpt-4_openai"
        assert transaction["end_user_id"] == end_user_id
        assert transaction["date"] == "2024-01-01"
        assert transaction["api_key"] == "test-key"
        assert transaction["model"] == "gpt-4"
        assert transaction["custom_llm_provider"] == "openai"
@pytest.mark.asyncio
async def test_add_spend_log_transaction_to_daily_end_user_transaction_skips_when_end_user_id_missing():
    """
    A spend-log payload with no "end_user" key must be skipped entirely:
    nothing is enqueued on the end-user daily spend queue.
    """
    writer = DBSpendUpdateWriter()
    mock_prisma = MagicMock()
    mock_prisma.get_request_status = MagicMock(return_value="success")

    # Note: deliberately no "end_user" key in this payload.
    payload = {
        "request_id": "req-2",
        "user": "test-user",
        "startTime": "2024-01-01T12:00:00",
        "api_key": "test-key",
        "model": "gpt-4",
        "custom_llm_provider": "openai",
        "model_group": "gpt-4-group",
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "spend": 0.2,
        "metadata": '{"usage_object": {}}',
    }

    writer.daily_end_user_spend_update_queue.add_update = AsyncMock()

    await writer.add_spend_log_transaction_to_daily_end_user_transaction(
        payload=payload,
        prisma_client=mock_prisma,
    )

    writer.daily_end_user_spend_update_queue.add_update.assert_not_called()

View File

@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI, HTTPException, Request, status
@@ -301,3 +301,99 @@ def test_customer_endpoints_error_schema_consistency(mock_prisma_client, mock_us
for key in ["message", "type", "code"]:
assert isinstance(error1[key], str), f"error1[{key}] should be a string"
assert isinstance(error2[key], str), f"error2[{key}] should be a string"
@pytest.mark.asyncio
async def test_get_customer_daily_activity_admin_param_passing(monkeypatch):
    """
    /customer/daily/activity must parse its comma-separated id filters and
    forward every parameter unchanged to get_daily_activity, returning its
    response as-is.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.management_endpoints import customer_endpoints
    from litellm.proxy.management_endpoints.customer_endpoints import (
        get_customer_daily_activity,
    )

    mock_prisma_client = AsyncMock()
    # No aliases in the DB -> empty metadata map forwarded.
    mock_prisma_client.db.litellm_endusertable.find_many = AsyncMock(
        return_value=[]
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

    # Replace the downstream aggregation helper so only the endpoint's own
    # parameter handling is under test.
    mocked_response = MagicMock(name="SpendAnalyticsPaginatedResponse")
    get_daily_activity_mock = AsyncMock(return_value=mocked_response)
    monkeypatch.setattr(
        customer_endpoints, "get_daily_activity", get_daily_activity_mock
    )

    auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin1")
    result = await get_customer_daily_activity(
        end_user_ids="end-user-1,end-user-2",
        start_date="2024-01-01",
        end_date="2024-01-31",
        model="gpt-4",
        api_key="test-key",
        page=2,
        page_size=5,
        exclude_end_user_ids="end-user-3",
        user_api_key_dict=auth,
    )

    get_daily_activity_mock.assert_awaited_once()
    kwargs = get_daily_activity_mock.call_args.kwargs
    assert kwargs["table_name"] == "litellm_dailyenduserspend"
    assert kwargs["entity_id_field"] == "end_user_id"
    # Comma-separated strings are split into lists before forwarding.
    assert kwargs["entity_id"] == ["end-user-1", "end-user-2"]
    assert kwargs["exclude_entity_ids"] == ["end-user-3"]
    assert kwargs["start_date"] == "2024-01-01"
    assert kwargs["end_date"] == "2024-01-31"
    assert kwargs["model"] == "gpt-4"
    assert kwargs["api_key"] == "test-key"
    assert kwargs["page"] == 2
    assert kwargs["page_size"] == 5
    assert result is mocked_response
@pytest.mark.asyncio
async def test_get_customer_daily_activity_with_end_user_aliases(monkeypatch):
    """
    End-user aliases fetched from litellm_endusertable must be forwarded to
    get_daily_activity as entity_metadata_field, keyed by user_id.
    """
    from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
    from litellm.proxy.management_endpoints import customer_endpoints
    from litellm.proxy.management_endpoints.customer_endpoints import (
        get_customer_daily_activity,
    )

    mock_prisma_client = AsyncMock()
    # Two end-user rows with aliases returned from the DB lookup.
    mock_end_user1 = MagicMock()
    mock_end_user1.user_id = "end-user-1"
    mock_end_user1.alias = "Customer One"
    mock_end_user2 = MagicMock()
    mock_end_user2.user_id = "end-user-2"
    mock_end_user2.alias = "Customer Two"
    mock_prisma_client.db.litellm_endusertable.find_many = AsyncMock(
        return_value=[mock_end_user1, mock_end_user2]
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)

    mocked_response = MagicMock(name="SpendAnalyticsPaginatedResponse")
    get_daily_activity_mock = AsyncMock(return_value=mocked_response)
    monkeypatch.setattr(
        customer_endpoints, "get_daily_activity", get_daily_activity_mock
    )

    auth = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin1")
    await get_customer_daily_activity(
        end_user_ids="end-user-1,end-user-2",
        start_date="2024-01-01",
        end_date="2024-01-31",
        model=None,
        api_key=None,
        page=1,
        page_size=10,
        exclude_end_user_ids=None,
        user_api_key_dict=auth,
    )

    kwargs = get_daily_activity_mock.call_args.kwargs
    assert kwargs["entity_metadata_field"] == {
        "end-user-1": {"alias": "Customer One"},
        "end-user-2": {"alias": "Customer Two"},
    }