Merge pull request #17562 from BerriAI/litellm_ui_compare_images

[Feature] Support Images in Compare UI
This commit is contained in:
yuneng-jiang
2025-12-05 17:24:05 -08:00
committed by GitHub
6 changed files with 230 additions and 31 deletions

View File

@@ -2,6 +2,7 @@ import { render, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { beforeEach, describe, expect, it, vi } from "vitest";
import CompareUI from "./CompareUI";
import { makeOpenAIChatCompletionRequest } from "../llm_calls/chat_completion";
vi.mock("../llm_calls/fetch_models", () => ({
fetchAvailableModels: vi.fn().mockResolvedValue([{ model_group: "gpt-4" }, { model_group: "gpt-3.5-turbo" }]),
@@ -11,6 +12,34 @@ vi.mock("../llm_calls/chat_completion", () => ({
makeOpenAIChatCompletionRequest: vi.fn().mockResolvedValue(undefined),
}));
let capturedOnImageUpload: ((file: File) => false) | null = null;
vi.mock("../chat_ui/ChatImageUpload", () => ({
default: ({ onImageUpload }: { onImageUpload: (file: File) => false }) => {
capturedOnImageUpload = onImageUpload;
return (
<div data-testid="chat-image-upload">
<button data-testid="trigger-upload">Upload</button>
</div>
);
},
}));
vi.mock("../chat_ui/ChatImageUtils", () => ({
createChatMultimodalMessage: vi.fn().mockResolvedValue({
role: "user",
content: [
{ type: "text", text: "test message" },
{ type: "image_url", image_url: { url: "" } },
],
}),
createChatDisplayMessage: vi.fn().mockReturnValue({
role: "user",
content: "test message [Image attached]",
imagePreviewUrl: "blob:test-url",
}),
}));
vi.mock("./components/ComparisonPanel", () => ({
ComparisonPanel: ({ comparison, onRemove }: { comparison: any; onRemove: () => void }) => (
<div data-testid={`comparison-panel-${comparison.id}`}>
@@ -22,8 +51,9 @@ vi.mock("./components/ComparisonPanel", () => ({
}));
vi.mock("./components/MessageInput", () => ({
MessageInput: ({ value, onChange, onSend, disabled }: any) => (
MessageInput: ({ value, onChange, onSend, disabled, hasAttachment, uploadComponent }: any) => (
<div data-testid="message-input">
{uploadComponent && <div data-testid="upload-component">{uploadComponent}</div>}
<textarea
data-testid="message-textarea"
value={value}
@@ -33,6 +63,7 @@ vi.mock("./components/MessageInput", () => ({
<button data-testid="send-button" onClick={onSend} disabled={disabled}>
Send
</button>
{hasAttachment && <div data-testid="has-attachment">Attachment</div>}
</div>
),
}));
@@ -51,6 +82,10 @@ beforeEach(() => {
dispatchEvent: () => false,
}),
});
global.URL.createObjectURL = vi.fn().mockReturnValue("blob:test-url");
global.URL.revokeObjectURL = vi.fn();
capturedOnImageUpload = null;
vi.clearAllMocks();
});
describe("CompareUI", () => {
@@ -88,4 +123,36 @@ describe("CompareUI", () => {
expect(getByTestId("comparison-panel-1")).toBeInTheDocument();
expect(getByTestId("comparison-panel-2")).toBeInTheDocument();
});
it("should handle image upload and send message with attachment", async () => {
const user = userEvent.setup();
const { getByTestId, queryByTestId } = render(
<CompareUI accessToken="test-token" disabledPersonalKeyCreation={false} />,
);
const file = new File(["test content"], "test-image.png", { type: "image/png" });
await waitFor(() => {
expect(capturedOnImageUpload).not.toBeNull();
});
if (capturedOnImageUpload) {
capturedOnImageUpload(file);
}
await waitFor(() => {
expect(queryByTestId("has-attachment")).toBeInTheDocument();
});
const textarea = getByTestId("message-textarea");
await user.type(textarea, "Describe this image");
const sendButton = getByTestId("send-button");
expect(sendButton).not.toBeDisabled();
await user.click(sendButton);
await waitFor(() => {
expect(makeOpenAIChatCompletionRequest).toHaveBeenCalled();
});
});
});

View File

@@ -1,14 +1,16 @@
"use client";
import React, { useEffect, useMemo, useState } from "react";
import { v4 as uuidv4 } from "uuid";
import { Select, Input, Tooltip, Button } from "antd";
import { ClearOutlined, PlusOutlined } from "@ant-design/icons";
import NotificationsManager from "@/components/molecules/notifications_manager";
import { fetchAvailableModels } from "../llm_calls/fetch_models";
import { makeOpenAIChatCompletionRequest } from "../llm_calls/chat_completion";
import { ClearOutlined, DeleteOutlined, FilePdfOutlined, PlusOutlined } from "@ant-design/icons";
import { Button, Input, Select, Tooltip } from "antd";
import { useEffect, useMemo, useState } from "react";
import { v4 as uuidv4 } from "uuid";
import ChatImageUpload from "../chat_ui/ChatImageUpload";
import { createChatDisplayMessage, createChatMultimodalMessage } from "../chat_ui/ChatImageUtils";
import type { TokenUsage } from "../chat_ui/ResponseMetrics";
import type { MessageType, VectorStoreSearchResponse } from "../chat_ui/types";
import { makeOpenAIChatCompletionRequest } from "../llm_calls/chat_completion";
import { fetchAvailableModels } from "../llm_calls/fetch_models";
import { ComparisonPanel } from "./components/ComparisonPanel";
import { MessageInput } from "./components/MessageInput";
export interface ComparisonInstance {
@@ -71,6 +73,8 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
const [modelOptions, setModelOptions] = useState<string[]>([]);
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [inputValue, setInputValue] = useState("");
const [uploadedFile, setUploadedFile] = useState<File | null>(null);
const [uploadedFilePreviewUrl, setUploadedFilePreviewUrl] = useState<string | null>(null);
const [apiKeySource, setApiKeySource] = useState<"session" | "custom">(
disabledPersonalKeyCreation ? "custom" : "session",
);
@@ -82,6 +86,13 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
}, 300);
return () => clearTimeout(timer);
}, [customApiKey]);
useEffect(() => {
return () => {
if (uploadedFilePreviewUrl) {
URL.revokeObjectURL(uploadedFilePreviewUrl);
}
};
}, [uploadedFilePreviewUrl]);
const effectiveApiKey = useMemo(
() => (apiKeySource === "session" ? accessToken || "" : debouncedCustomApiKey.trim()),
[apiKeySource, accessToken, debouncedCustomApiKey],
@@ -215,6 +226,21 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
);
});
};
const handleFileUpload = (file: File): false => {
if (uploadedFilePreviewUrl) {
URL.revokeObjectURL(uploadedFilePreviewUrl);
}
setUploadedFile(file);
setUploadedFilePreviewUrl(URL.createObjectURL(file));
return false;
};
const handleRemoveFile = () => {
if (uploadedFilePreviewUrl) {
URL.revokeObjectURL(uploadedFilePreviewUrl);
}
setUploadedFile(null);
setUploadedFilePreviewUrl(null);
};
const clearAllChats = () => {
setComparisons((prev) =>
prev.map((comparison) => ({
@@ -225,6 +251,7 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
})),
);
setInputValue("");
handleRemoveFile();
};
const appendAssistantChunk = (comparisonId: string, chunk: string, model?: string) => {
if (!chunk) {
@@ -389,9 +416,10 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
);
};
const canUseSessionKey = Boolean(accessToken);
const handleSendMessage = (input: string) => {
const handleSendMessage = async (input: string) => {
const trimmed = input.trim();
if (!trimmed) {
const hasAttachment = Boolean(uploadedFile);
if (!trimmed && !hasAttachment) {
return;
}
if (!effectiveApiKey) {
@@ -406,6 +434,17 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
NotificationsManager.fromBackend("Select a model before sending a message.");
return;
}
const apiUserMessage = hasAttachment
? await createChatMultimodalMessage(trimmed, uploadedFile as File)
: { role: "user", content: trimmed };
const displayUserMessage = createChatDisplayMessage(
trimmed,
hasAttachment,
uploadedFilePreviewUrl || undefined,
uploadedFile?.name,
);
const preparedTargets = new Map<
string,
{
@@ -417,15 +456,19 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
guardrails: string[];
temperature: number;
maxTokens: number;
messages: MessageType[];
displayMessages: MessageType[];
apiChatHistory: Array<{ role: string; content: string | any[] }>;
}
>();
targetComparisons.forEach((comparison) => {
const traceId = comparison.traceId ?? uuidv4();
const userMessage: MessageType = {
role: "user",
content: trimmed,
};
const apiChatHistory = [
...comparison.messages.map(({ role, content }) => ({
role,
content: Array.isArray(content) ? content : typeof content === "string" ? content : "",
})),
apiUserMessage,
];
preparedTargets.set(comparison.id, {
id: comparison.id,
model: comparison.model,
@@ -435,7 +478,8 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
guardrails: comparison.guardrails,
temperature: comparison.temperature,
maxTokens: comparison.maxTokens,
messages: [...comparison.messages, userMessage],
displayMessages: [...comparison.messages, displayUserMessage],
apiChatHistory,
});
});
if (preparedTargets.size === 0) {
@@ -450,23 +494,22 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
return {
...comparison,
traceId: prepared.traceId,
messages: prepared.messages,
messages: prepared.displayMessages,
isLoading: true,
};
}),
);
setInputValue("");
handleRemoveFile();
preparedTargets.forEach((prepared) => {
const apiChatHistory = prepared.messages.map(({ role, content }) => ({
role,
content: typeof content === "string" ? content : "",
}));
const tags = prepared.tags.length > 0 ? prepared.tags : undefined;
const vectorStoreIds = prepared.vectorStores.length > 0 ? prepared.vectorStores : undefined;
const guardrails = prepared.guardrails.length > 0 ? prepared.guardrails : undefined;
const comparison = comparisons.find((c) => c.id === prepared.id);
const useAdvancedParams = comparison?.useAdvancedParams ?? false;
makeOpenAIChatCompletionRequest(
apiChatHistory,
prepared.apiChatHistory,
(chunk, model) => appendAssistantChunk(prepared.id, chunk, model),
prepared.model,
effectiveApiKey,
@@ -536,18 +579,19 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
setInputValue(value);
};
const handleSubmit = () => {
handleSendMessage(inputValue);
setInputValue("");
void handleSendMessage(inputValue);
};
const handleFollowUpSelect = (question: string) => {
setInputValue(question);
};
const hasMessages = comparisons.some((comparison) => comparison.messages.length > 0);
const isAnyComparisonLoading = comparisons.some((comparison) => comparison.isLoading);
const showSuggestedPrompts = !hasMessages && !isAnyComparisonLoading;
const hasAttachment = Boolean(uploadedFile);
const isUploadedFilePdf = Boolean(uploadedFile?.name.toLowerCase().endsWith(".pdf"));
const showSuggestedPrompts = !hasMessages && !isAnyComparisonLoading && !hasAttachment;
return (
<div className="w-full h-full p-4 bg-white">
<div className="rounded-2xl border border-gray-200 bg-white shadow-sm min-h-[calc(100vh-140px)] flex flex-col">
<div className="rounded-2xl border border-gray-200 bg-white shadow-sm min-h-[calc(100vh-160px)] flex flex-col">
<div className="border-b px-4 py-2">
<div className="flex flex-wrap items-center justify-between gap-3">
<div className="flex items-center gap-2">
@@ -620,7 +664,9 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
<div className="w-full max-w-3xl px-4">
<div className="border border-gray-200 shadow-lg rounded-xl bg-white p-4">
<div className="flex items-center justify-between gap-4 mb-3 min-h-8">
{showSuggestedPrompts ? (
{hasAttachment ? (
<span className="text-sm text-gray-500">Attachment ready to send</span>
) : showSuggestedPrompts ? (
<div className="flex items-center gap-2 overflow-x-auto">
{SUGGESTED_PROMPTS.map((prompt) => (
<button
@@ -633,7 +679,7 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
</button>
))}
</div>
) : haveAllResponses ? (
) : haveAllResponses && !hasAttachment ? (
<div className="flex items-center gap-2 overflow-x-auto">
{GENERIC_FOLLOW_UPS.map((question) => (
<button
@@ -655,11 +701,49 @@ export default function CompareUI({ accessToken, disabledPersonalKeyCreation }:
<span className="text-sm text-gray-500">Send a prompt to compare models</span>
)}
</div>
{uploadedFile && (
<div className="mb-3">
<div className="flex items-center gap-3 p-3 bg-gray-50 rounded-lg border border-gray-200">
<div className="relative inline-block">
{isUploadedFilePdf ? (
<div className="w-10 h-10 rounded-md bg-red-500 flex items-center justify-center">
<FilePdfOutlined style={{ fontSize: "16px", color: "white" }} />
</div>
) : (
<img
src={uploadedFilePreviewUrl || ""}
alt="Upload preview"
className="w-10 h-10 rounded-md border border-gray-200 object-cover"
/>
)}
</div>
<div className="flex-1 min-w-0">
<div className="text-sm font-medium text-gray-900 truncate">{uploadedFile.name}</div>
<div className="text-xs text-gray-500">{isUploadedFilePdf ? "PDF" : "Image"}</div>
</div>
<button
className="flex items-center justify-center w-6 h-6 text-gray-400 hover:text-gray-600 hover:bg-gray-200 rounded-full transition-colors"
onClick={handleRemoveFile}
>
<DeleteOutlined style={{ fontSize: "12px" }} />
</button>
</div>
</div>
)}
<MessageInput
value={inputValue}
onChange={handleInputChange}
onSend={handleSubmit}
disabled={comparisons.length === 0 || comparisons.every((comparison) => comparison.isLoading)}
hasAttachment={hasAttachment}
uploadComponent={
<ChatImageUpload
chatUploadedImage={uploadedFile}
chatImagePreviewUrl={uploadedFilePreviewUrl}
onImageUpload={handleFileUpload}
onRemoveImage={handleRemoveFile}
/>
}
/>
</div>
</div>

View File

@@ -17,6 +17,15 @@ vi.mock("../../chat_ui/SearchResultsDisplay", () => ({
SearchResultsDisplay: () => <div data-testid="search-results">SearchResultsDisplay</div>,
}));
vi.mock("../../chat_ui/ChatImageRenderer", () => ({
default: ({ message }: { message: any }) =>
message.imagePreviewUrl ? (
<div data-testid="chat-image-renderer">
<img src={message.imagePreviewUrl} alt="User uploaded image" />
</div>
) : null,
}));
describe("MessageDisplay", () => {
it("should render", () => {
const messages: MessageType[] = [
@@ -63,4 +72,24 @@ describe("MessageDisplay", () => {
expect(getByText("2+2 equals 4")).toBeInTheDocument();
expect(getByTestId("response-metrics")).toBeInTheDocument();
});
it("should display image attachment in user message", () => {
const messages: MessageType[] = [
{
role: "user",
content: "What is in this image? [Image attached]",
imagePreviewUrl: "blob:test-image-url",
},
{
role: "assistant",
content: "This is a test image",
model: "gpt-4",
},
];
const { getByTestId, getByText } = render(<MessageDisplay messages={messages} isLoading={false} />);
expect(getByText("What is in this image? [Image attached]")).toBeInTheDocument();
expect(getByTestId("chat-image-renderer")).toBeInTheDocument();
const image = getByTestId("chat-image-renderer").querySelector("img");
expect(image).toHaveAttribute("src", "blob:test-image-url");
});
});

View File

@@ -1,8 +1,9 @@
import { Bot, Loader2, UserRound } from "lucide-react";
import React from "react";
import ReactMarkdown from "react-markdown";
import { Prism as SyntaxHighlighter } from "react-syntax-highlighter";
import { coy } from "react-syntax-highlighter/dist/esm/styles/prism";
import { Bot, Loader2, UserRound } from "lucide-react";
import ChatImageRenderer from "../../chat_ui/ChatImageRenderer";
import ReasoningContent from "../../chat_ui/ReasoningContent";
import ResponseMetrics from "../../chat_ui/ResponseMetrics";
import { SearchResultsDisplay } from "../../chat_ui/SearchResultsDisplay";
@@ -56,6 +57,7 @@ export function MessageDisplay({ messages, isLoading }: MessageDisplayProps) {
hyphens: "auto",
}}
>
<ChatImageRenderer message={message} />
<ReactMarkdown
components={{
code({

View File

@@ -21,4 +21,16 @@ describe("MessageInput", () => {
expect(button).toBeDisabled();
});
it("should enable send button when hasAttachment is true even with empty value", () => {
const onChange = vi.fn();
const onSend = vi.fn();
const uploadComponent = <div data-testid="upload-component">Upload</div>;
const { container, getByTestId } = render(
<MessageInput value="" onChange={onChange} onSend={onSend} hasAttachment={true} uploadComponent={uploadComponent} />,
);
const button = container.querySelector("button") as HTMLButtonElement;
expect(getByTestId("upload-component")).toBeInTheDocument();
expect(button).not.toBeDisabled();
});
});

View File

@@ -9,13 +9,17 @@ interface MessageInputProps {
onChange: (value: string) => void;
onSend: () => void;
disabled?: boolean;
hasAttachment?: boolean;
uploadComponent?: React.ReactNode;
}
export function MessageInput({ value, onChange, onSend, disabled }: MessageInputProps) {
export function MessageInput({ value, onChange, onSend, disabled, hasAttachment, uploadComponent }: MessageInputProps) {
const canSend = !disabled && (value.trim().length > 0 || Boolean(hasAttachment));
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (e.key === "Enter" && !e.shiftKey) {
e.preventDefault();
if (!disabled && value.trim()) {
if (canSend) {
onSend();
}
}
@@ -24,6 +28,7 @@ export function MessageInput({ value, onChange, onSend, disabled }: MessageInput
return (
<div className="flex items-center gap-2">
<div className="flex items-center flex-1 bg-white border border-gray-300 rounded-xl px-3 py-1 min-h-[44px]">
{uploadComponent && <div className="flex-shrink-0 mr-2">{uploadComponent}</div>}
<TextArea
value={value}
onChange={(e) => onChange(e.target.value)}
@@ -42,7 +47,7 @@ export function MessageInput({ value, onChange, onSend, disabled }: MessageInput
lineHeight: "20px",
}}
/>
<Button onClick={onSend} disabled={disabled || !value.trim()} icon={<ArrowUpOutlined />} shape="circle" />
<Button onClick={onSend} disabled={!canSend} icon={<ArrowUpOutlined />} shape="circle" />
</div>
</div>
);