[feat]: implement "local" caption upsampling for Flux.2 (#12718)

* feat: implement caption upsampling for flux.2.

* doc

* up

* fix

* up

* fix system prompts 🤷‍

* up

* up

* up
Sayak Paul
2025-12-02 04:27:24 +05:30
committed by GitHub
parent 394a48d169
commit 564079f295
5 changed files with 255 additions and 23 deletions

View File

@@ -26,6 +26,12 @@ Original model checkpoints for Flux can be found [here](https://huggingface.co/b
>
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## Caption upsampling
Flux.2 can potentially generate better outputs with more descriptive prompts. We can "upsample"
an input prompt by setting the `caption_upsample_temperature` argument in the pipeline call.
The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L140) recommends a value of 0.15.
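For example, a minimal sketch of enabling it (the checkpoint id below is a placeholder; substitute the actual Flux.2 repository):

```py
import torch
from diffusers import Flux2Pipeline

# Placeholder repo id -- replace with the actual Flux.2 checkpoint.
pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a cozy cabin in a snowy forest at dusk",
    caption_upsample_temperature=0.15,  # rewrites the prompt with the text encoder before encoding
).images[0]
image.save("flux2_upsampled.png")
```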
## Flux2Pipeline
[[autodoc]] Flux2Pipeline

View File

@@ -1,5 +1,8 @@
[tool.ruff]
line-length = 119
extend-exclude = [
"src/diffusers/pipelines/flux2/system_messages.py",
]
[tool.ruff.lint]
# Never enforce `E501` (line length violations).

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List
import PIL.Image
@@ -98,7 +98,7 @@ class Flux2ImageProcessor(VaeImageProcessor):
return image
@staticmethod
def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size
scale = math.sqrt(target_area / (image_width * image_height))
@@ -107,6 +107,14 @@ class Flux2ImageProcessor(VaeImageProcessor):
return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
@staticmethod
def _resize_if_exceeds_area(image, target_area=1024 * 1024) -> PIL.Image.Image:
image_width, image_height = image.size
pixel_count = image_width * image_height
if pixel_count <= target_area:
return image
return Flux2ImageProcessor._resize_to_target_area(image, target_area)
def _resize_and_crop(
self,
image: PIL.Image.Image,
@@ -136,3 +144,35 @@ class Flux2ImageProcessor(VaeImageProcessor):
bottom = top + height
return image.crop((left, top, right, bottom))
# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L310C1-L339C19
@staticmethod
def concatenate_images(images: List[PIL.Image.Image]) -> PIL.Image.Image:
"""
Concatenate a list of PIL images horizontally with center alignment and white background.
"""
# If only one image, return a copy of it
if len(images) == 1:
return images[0].copy()
# Convert all images to RGB if not already
images = [img.convert("RGB") if img.mode != "RGB" else img for img in images]
# Calculate dimensions for horizontal concatenation
total_width = sum(img.width for img in images)
max_height = max(img.height for img in images)
# Create new image with white background
background_color = (255, 255, 255)
new_img = PIL.Image.new("RGB", (total_width, max_height), background_color)
# Paste images with center alignment
x_offset = 0
for img in images:
y_offset = (max_height - img.height) // 2
new_img.paste(img, (x_offset, y_offset))
x_offset += img.width
return new_img
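A short sketch of how these two helpers behave, using solid-color stand-in images (the import path follows this file's location):

```python
import PIL.Image

from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor

# Two differently sized stand-in images.
img_a = PIL.Image.new("RGB", (800, 600), (200, 30, 30))
img_b = PIL.Image.new("RGB", (640, 480), (30, 30, 200))

# Horizontal concatenation, center-aligned on a white background.
combined = Flux2ImageProcessor.concatenate_images([img_a, img_b])
print(combined.size)  # (1440, 600)

# Downscales only when the pixel count exceeds the cap; smaller images pass through unchanged.
capped = Flux2ImageProcessor._resize_if_exceeds_area(combined, target_area=768 * 768)
```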

View File

@@ -28,6 +28,7 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
from .system_messages import SYSTEM_MESSAGE, SYSTEM_MESSAGE_UPSAMPLING_I2I, SYSTEM_MESSAGE_UPSAMPLING_T2I
if is_torch_xla_available():
@@ -56,25 +57,105 @@ EXAMPLE_DOC_STRING = """
```
"""
UPSAMPLING_MAX_IMAGE_SIZE = 768**2
def format_text_input(prompts: List[str], system_message: str = None):
# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
def format_input(
prompts: List[str],
system_message: str = SYSTEM_MESSAGE,
images: Optional[Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None,
):
"""
Format a batch of text prompts into the conversation format expected by apply_chat_template. Optionally, add images
to the input.
Args:
prompts: List of text prompts
system_message: System message to use (default: SYSTEM_MESSAGE)
images (optional): List of images to add to the input.
Returns:
List of conversations, where each conversation is a list of message dicts
"""
# Remove [IMG] tokens from prompts to avoid Pixtral validation issues
# when truncation is enabled. The processor counts [IMG] tokens and fails
# if the count changes after truncation.
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
if images is None or len(images) == 0:
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
else:
assert len(images) == len(prompts), "Number of images must match number of prompts"
messages = [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
]
for _ in cleaned_txt
]
for i, (el, images) in enumerate(zip(messages, images)):
# optionally add the images per batch element.
if images is not None:
el.append(
{
"role": "user",
"content": [{"type": "image", "image": image_obj} for image_obj in images],
}
)
# add the text.
el.append(
{
"role": "user",
"content": [{"type": "text", "text": cleaned_txt[i]}],
}
)
return messages
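As a rough illustration of the structure this builds, a single prompt with one reference image comes out as something like:

```python
# format_input(["make the sky pink"], system_message=SYSTEM_MESSAGE_UPSAMPLING_I2I, images=[[img]])
# returns:
# [
#     [
#         {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE_UPSAMPLING_I2I}]},
#         {"role": "user", "content": [{"type": "image", "image": img}]},
#         {"role": "user", "content": [{"type": "text", "text": "make the sky pink"}]},
#     ]
# ]
```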
# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L49C5-L66C19
def _validate_and_process_images(
images: List[List[PIL.Image.Image]] | List[PIL.Image.Image],
image_processor: Flux2ImageProcessor,
upsampling_max_image_size: int,
) -> List[List[PIL.Image.Image]]:
# Simple validation: ensure it's a list of PIL images or list of lists of PIL images
if not images:
return []
# Check if it's a list of lists or a list of images
if isinstance(images[0], PIL.Image.Image):
# It's a list of images, convert to list of lists
images = [[im] for im in images]
# potentially concatenate multiple images to reduce the size
images = [[image_processor.concatenate_images(img_i)] if len(img_i) > 1 else img_i for img_i in images]
# cap the pixels
images = [
[image_processor._resize_if_exceeds_area(img_i, upsampling_max_image_size) for img_i in img_i]
for img_i in images
]
return images
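A quick sketch of the normalization this performs, assuming it is called from within this module with stand-in images:

```python
import PIL.Image

processor = Flux2ImageProcessor()
small = PIL.Image.new("RGB", (512, 512))
large = PIL.Image.new("RGB", (1024, 1024))

# A flat list means one reference image per prompt; it becomes a list of single-image lists,
# and only `large` (1024 * 1024 > 768 * 768 pixels) gets downscaled.
per_prompt = _validate_and_process_images([small, large], processor, 768 * 768)

# Multiple references for one prompt are concatenated into a single image first, then capped.
multi_ref = _validate_and_process_images([[small, large]], processor, 768 * 768)
```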
# Taken from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/sampling.py#L251
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
@@ -214,9 +295,10 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
self.tokenizer_max_length = 512
self.default_sample_size = 128
self.system_message = SYSTEM_MESSAGE
self.system_message_upsampling_t2i = SYSTEM_MESSAGE_UPSAMPLING_T2I
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
@staticmethod
def _get_mistral_3_small_prompt_embeds(
@@ -226,9 +308,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
system_message: str = SYSTEM_MESSAGE,
hidden_states_layers: List[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
@@ -237,7 +317,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
prompt = [prompt] if isinstance(prompt, str) else prompt
# Format input messages
messages_batch = format_input(prompts=prompt, system_message=system_message)
# Process all messages at once
inputs = tokenizer.apply_chat_template(
@@ -426,6 +506,68 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
return torch.stack(x_list, dim=0)
def upsample_prompt(
self,
prompt: Union[str, List[str]],
images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]] = None,
temperature: float = 0.15,
device: torch.device = None,
) -> List[str]:
prompt = [prompt] if isinstance(prompt, str) else prompt
device = self.text_encoder.device if device is None else device
# Set system message based on whether images are provided
if images is None or len(images) == 0 or images[0] is None:
system_message = SYSTEM_MESSAGE_UPSAMPLING_T2I
else:
system_message = SYSTEM_MESSAGE_UPSAMPLING_I2I
# Validate and process the input images
if images:
images = _validate_and_process_images(images, self.image_processor, self.upsampling_max_image_size)
# Format input messages
messages_batch = format_input(prompts=prompt, system_message=system_message, images=images)
# Process all messages at once
# with image processing, a max_length that is too short can raise an error here.
inputs = self.tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=2048,
)
# Move to device
inputs["input_ids"] = inputs["input_ids"].to(device)
inputs["attention_mask"] = inputs["attention_mask"].to(device)
if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(device, self.text_encoder.dtype)
# Generate text using the model's generate method
generated_ids = self.text_encoder.generate(
**inputs,
max_new_tokens=512,
do_sample=True,
temperature=temperature,
use_cache=True,
)
# Decode only the newly generated tokens (skip input tokens)
# Extract only the generated portion
input_length = inputs["input_ids"].shape[1]
generated_tokens = generated_ids[:, input_length:]
upsampled_prompt = self.tokenizer.tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
return upsampled_prompt
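Since this is a regular pipeline method, it can also be called on its own to inspect the rewritten prompt before generating (a sketch, assuming `pipe` is a loaded `Flux2Pipeline`):

```python
upsampled = pipe.upsample_prompt("a cat lounging on a velvet sofa", temperature=0.15)
print(upsampled[0])  # the rewritten, more descriptive prompt

image = pipe(prompt=upsampled[0]).images[0]
```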
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -620,6 +762,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int] = (10, 20, 30),
caption_upsample_temperature: float = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -635,11 +778,11 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
guidance_scale (`float`, *optional*, defaults to 1.0):
Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
a model to generate images more aligned with `prompt` at the expense of lower image quality.
Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
the [paper](https://huggingface.co/papers/2210.03142) to learn more.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -684,6 +827,9 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
text_encoder_out_layers (`Tuple[int]`):
Layer indices to use in the `text_encoder` to derive the final prompt embeddings.
caption_upsample_temperature (`float`, *optional*):
When specified, the prompt is first "upsampled" (rewritten to be more descriptive) by the text
encoder, which can improve outputs. A value of 0.15 is recommended.
Examples: Examples:
@@ -718,6 +864,10 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
device = self._execution_device
# 3. prepare text embeddings
if caption_upsample_temperature:
prompt = self.upsample_prompt(
prompt, images=image, temperature=caption_upsample_temperature, device=device
)
prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,

View File

@@ -0,0 +1,33 @@
# docstyle-ignore
"""
These system prompts come from:
https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/system_messages.py#L54
"""
# docstyle-ignore
SYSTEM_MESSAGE = """You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object
attribution and actions without speculation."""
# docstyle-ignore
SYSTEM_MESSAGE_UPSAMPLING_T2I = """You are an expert prompt engineer for FLUX.2 by Black Forest Labs. Rewrite user prompts to be more descriptive while strictly preserving their core subject and intent.
Guidelines:
1. Structure: Keep structured inputs structured (enhance within fields). Convert natural language to detailed paragraphs.
2. Details: Add concrete visual specifics - form, scale, textures, materials, lighting (quality, direction, color), shadows, spatial relationships, and environmental context.
3. Text in Images: Put ALL text in quotation marks, matching the prompt's language. Always provide explicit quoted text for objects that would contain text in reality (signs, labels, screens, etc.) - without it, the model generates gibberish.
Output only the revised prompt and nothing else."""
# docstyle-ignore
SYSTEM_MESSAGE_UPSAMPLING_I2I = """You are FLUX.2 by Black Forest Labs, an image-editing expert. You convert editing requests into one concise instruction (50-80 words, ~30 for brief requests).
Rules:
- Single instruction only, no commentary
- Use clear, analytical language (avoid "whimsical," "cascading," etc.)
- Specify what changes AND what stays the same (face, lighting, composition)
- Reference actual image elements
- Turn negatives into positives ("don't change X" → "keep X")
- Make abstractions concrete ("futuristic" → "glowing cyan neon, metallic panels")
- Keep content PG-13
Output only the final instruction in plain text and nothing else."""