Compare commits

...

4 Commits

Author  SHA1        Message                                 Date
DN6     921b959b9a  update                                  2025-12-01 10:42:35 +05:30
DN6     9391a5465d  Merge branch 'main' into flux2-modular  2025-11-28 22:23:38 +05:30
DN6     d780d1a42a  update                                  2025-11-28 15:27:30 +05:30
DN6     9264459f88  update                                  2025-11-28 12:57:00 +05:30
13 changed files with 2029 additions and 1 deletion

View File

@@ -399,6 +399,8 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -1091,6 +1093,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,

View File

@@ -52,6 +52,10 @@ else:
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2ModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
@@ -71,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = [
"Flux2TextEncoderStep",
"Flux2RemoteTextEncoderStep",
"Flux2ProcessImagesInputStep",
"Flux2VaeEncoderStep",
]
_import_structure["before_denoise"] = [
"Flux2SetTimestepsStep",
"Flux2PrepareLatentsStep",
"Flux2RoPEInputsStep",
"Flux2PrepareImageLatentsStep",
]
_import_structure["denoise"] = [
"Flux2LoopDenoiser",
"Flux2LoopAfterDenoiser",
"Flux2DenoiseLoopWrapper",
"Flux2DenoiseStep",
]
_import_structure["decoders"] = ["Flux2DecodeStep"]
_import_structure["inputs"] = [
"Flux2TextInputStep",
"Flux2ImageInputStep",
]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBeforeDenoiseStep",
"Flux2AutoBlocks",
"Flux2AutoDecodeStep",
"Flux2AutoDenoiseStep",
"Flux2AutoImageInputStep",
"Flux2AutoTextEncoderStep",
"Flux2AutoTextInputStep",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import (
Flux2DenoiseLoopWrapper,
Flux2DenoiseStep,
Flux2LoopAfterDenoiser,
Flux2LoopDenoiser,
)
from .encoders import (
Flux2ProcessImagesInputStep,
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ImageInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
TEXT2IMAGE_BLOCKS,
Flux2AutoBeforeDenoiseStep,
Flux2AutoBlocks,
Flux2AutoDecodeStep,
Flux2AutoDenoiseStep,
Flux2AutoImageInputStep,
Flux2AutoTextEncoderStep,
Flux2AutoTextInputStep,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
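
The lazy-module pattern above defers importing the heavy submodules until an attribute is first accessed. A minimal, hedged sketch of what a consumer sees, assuming this branch of diffusers is installed:

from diffusers.modular_pipelines.flux2 import Flux2AutoBlocks, Flux2ModularPipeline

blocks = Flux2AutoBlocks()                       # resolved lazily from .modular_blocks
print(type(blocks).__name__)                     # Flux2AutoBlocks
print(Flux2ModularPipeline.default_blocks_name)  # Flux2AutoBlocks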

View File

@@ -0,0 +1,505 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import numpy as np
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
"""Compute empirical mu for Flux2 timestep scheduling."""
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class Flux2SetTimestepsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
height = block_state.height or components.default_height
width = block_state.width or components.default_width
vae_scale_factor = components.vae_scale_factor
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
timesteps = block_state.timesteps
if timesteps is None and sigmas is None:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
]
@staticmethod
def check_inputs(components, block_state):
vae_scale_factor = components.vae_scale_factor
if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
)
@staticmethod
def _prepare_latent_ids(latents: torch.Tensor):
"""
Generates 4D position coordinates (T, H, W, L) for latent tensors.
Args:
latents: Latent tensor of shape (B, C, H, W)
Returns:
Position IDs tensor of shape (B, H*W, 4)
"""
batch_size, _, height, width = latents.shape
t = torch.arange(1)
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1)
latent_ids = torch.cartesian_prod(t, h, w, l)
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
return latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@staticmethod
def prepare_latents(
comp,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
latents = self.prepare_latents(
components,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
latent_ids = self._prepare_latent_ids(latents)
latent_ids = latent_ids.to(block_state.device)
latents = self._pack_latents(latents)
block_state.latents = latents
block_state.latent_ids = latent_ids
self.set_block_state(state, block_state)
return components, state
class Flux2RoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares image latents and their position IDs for Flux2 image conditioning."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image_latents", type_hint=List[torch.Tensor]),
InputParam("batch_size", required=True, type_hint=int),
InputParam("num_images_per_prompt", default=1, type_hint=int),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning",
),
OutputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents",
),
]
@staticmethod
def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
"""
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
Args:
image_latents: A list of image latent feature tensors of shape (1, C, H, W).
scale: Factor used to define the time separation between latents.
Returns:
Combined coordinate tensor of shape (1, N_total, 4)
"""
if not isinstance(image_latents, list):
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
image_latents = block_state.image_latents
if image_latents is None:
block_state.image_latents = None
block_state.image_latent_ids = None
else:
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
image_latent_ids = self._prepare_image_ids(image_latents)
packed_latents = []
for latent in image_latents:
packed = self._pack_latents(latent)
packed = packed.squeeze(0)
packed_latents.append(packed)
image_latents = torch.cat(packed_latents, dim=0)
image_latents = image_latents.unsqueeze(0)
image_latents = image_latents.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.to(device)
block_state.image_latents = image_latents
block_state.image_latent_ids = image_latent_ids
self.set_block_state(state, block_state)
return components, state
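
A quick, hedged sanity check of compute_empirical_mu as defined at the top of this file; the image_seq_len values are illustrative (for example, a 1024x1024 image with a VAE scale factor of 16 gives a 64x64 latent grid and image_seq_len = 32 * 32 = 1024):

for seq_len in (1024, 4096, 8192):
    mu = compute_empirical_mu(image_seq_len=seq_len, num_steps=50)
    print(seq_len, round(mu, 4))
# seq_len > 4300 falls on the a2/b2 line directly; smaller values interpolate
# between the 10-step and 200-step empirical fits.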

View File

@@ -0,0 +1,146 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="Position IDs for the latents, used for unpacking",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
"""
Unpack latents using position IDs to scatter tokens into place.
Args:
x: Packed latents tensor of shape (B, seq_len, C)
x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates
Returns:
Unpacked latents tensor of shape (B, C, H, W)
"""
x_list = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
out = out.view(h, w, ch).permute(2, 0, 1)
x_list.append(out)
return torch.stack(x_list, dim=0)
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
latents = latents.permute(0, 1, 4, 2, 5, 3)
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
return latents
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
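
A hedged shape check for the unpatchify helper above: patchified latents of shape (B, 4*C, H/2, W/2) are rearranged back to (B, C, H, W) before the VAE decode:

import torch

packed = torch.randn(1, 128, 32, 32)             # 32 latent channels * 4 patch positions
unpacked = Flux2DecodeStep._unpatchify_latents(packed)
print(unpacked.shape)                            # torch.Size([1, 32, 64, 64])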

View File

@@ -0,0 +1,252 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2LoopDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Mistral3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that updates the latents after denoising. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoises the latents over `timesteps`. "
"The specific steps within each iteration can be customized with `sub_blocks` attribute"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self.set_block_state(state, block_state)
return components, state
class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2LoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
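
For orientation, a hedged look at how Flux2DenoiseStep composes the two loop blocks defined above; block_names gives the per-iteration order used by the loop wrapper's loop_step:

step = Flux2DenoiseStep()
print(step.block_names)   # ["denoiser", "after_denoiser"]
print(step.description)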

View File

@@ -0,0 +1,420 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def format_text_input(prompts: List[str], system_message: str = None):
"""Format prompts for Mistral3 chat template."""
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2TextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
# fmt: off
DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
# fmt: on
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
ComponentSpec("tokenizer", AutoProcessor),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from Mistral3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
def _get_mistral_3_prompt_embeds(
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
# fmt: off
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
# fmt: on
hidden_states_layers: Tuple[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
inputs = tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=block_state.device,
max_sequence_length=block_state.max_sequence_length,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using a remote API endpoint"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from remote API used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
import io
import requests
from huggingface_hub import get_token
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
response = requests.post(
self.REMOTE_URL,
json={"prompt": prompt},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)
self.set_block_state(state, block_state)
return components, state
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image"),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
images = block_state.image
if images is None:
block_state.condition_images = None
else:
if not isinstance(images, list):
images = [images]
condition_images = []
for img in images:
components.image_processor.check_image_input(img)
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
multiple_of = components.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
condition_img = components.image_processor.preprocess(
img, height=image_height, width=image_width, resize_mode="crop"
)
condition_images.append(condition_img)
if block_state.height is None:
block_state.height = image_height
if block_state.width is None:
block_state.width = image_width
block_state.condition_images = condition_images
self.set_block_state(state, block_state)
return components, state
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("vae", AutoencoderKLFlux2)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("condition_images", type_hint=List[torch.Tensor]),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=List[torch.Tensor],
description="List of latent representations for each reference image",
),
]
@staticmethod
def _patchify_latents(latents):
"""Convert latents to patchified format for Flux2."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
return latents
def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
"""Encode a single image using Flux2 VAE with batch norm normalization."""
if image.ndim != 4:
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
condition_images = block_state.condition_images
if condition_images is None:
block_state.image_latents = None
else:
device = components._execution_device
dtype = components.vae.dtype
image_latents = []
for image in condition_images:
image = image.to(device=device, dtype=dtype)
latent = self._encode_vae_image(
vae=components.vae,
image=image,
generator=block_state.generator,
)
image_latents.append(latent)
block_state.image_latents = image_latents
self.set_block_state(state, block_state)
return components, state
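
A hedged illustration of the chat payload produced by format_text_input above, which is what gets run through the Mistral3 chat template:

msgs = format_text_input(
    prompts=["A painting of a squirrel eating a burger"],
    system_message=Flux2TextEncoderStep.DEFAULT_SYSTEM_MESSAGE,
)
print(msgs[0][0]["role"], msgs[0][1]["role"])   # system user
print(msgs[0][1]["content"][0]["text"])         # the cleaned user prompt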

View File

@@ -0,0 +1,140 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__)
class Flux2TextInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"Text input processing step that standardizes text embeddings for Flux2 pipeline.\n"
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class Flux2ImageInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"Image input processing step that prepares image latents for Flux2 conditioning.\n"
"This step expands image latents to match the batch size."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam("batch_size", required=True, type_hint=int),
InputParam("image_latents", type_hint=torch.Tensor),
InputParam("image_latent_ids", type_hint=torch.Tensor),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents expanded to batch size",
),
OutputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Image latent position IDs expanded to batch size",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
image_latents = block_state.image_latents
image_latent_ids = block_state.image_latent_ids
target_batch_size = block_state.batch_size * block_state.num_images_per_prompt
if image_latents is not None:
block_state.image_latents = image_latents.repeat(target_batch_size, 1, 1)
if image_latent_ids is not None:
block_state.image_latent_ids = image_latent_ids.repeat(target_batch_size, 1, 1)
self.set_block_state(state, block_state)
return components, state
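
A hedged shape check for the batch expansion performed by Flux2TextInputStep: (batch_size, seq_len, dim) embeddings are tiled to (batch_size * num_images_per_prompt, seq_len, dim):

import torch

prompt_embeds = torch.randn(2, 8, 16)
num_images_per_prompt = 3
expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(2 * num_images_per_prompt, 8, -1)
print(expanded.shape)   # torch.Size([6, 8, 16])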

View File

@@ -0,0 +1,237 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2ProcessImagesInputStep,
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ImageInputStep,
Flux2TextInputStep,
)
class Flux2AutoTextInputStep(AutoPipelineBlocks):
block_classes = [Flux2TextInputStep]
block_names = ["text_input"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Text input step that processes text embeddings and determines batch size.\n"
" - `Flux2TextInputStep` is always used."
)
class Flux2AutoImageInputStep(AutoPipelineBlocks):
block_classes = [Flux2ImageInputStep]
block_names = ["image_input"]
block_trigger_inputs = ["image_latents"]
@property
def description(self):
return (
"Image input step that expands image latents to match batch size.\n"
" - `Flux2ImageInputStep` is used when `image_latents` is provided.\n"
" - Skipped when no image conditioning is used."
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2VaeEncoderBlocks.values()
block_names = Flux2VaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2VaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
class Flux2AutoTextEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2RemoteTextEncoderStep, Flux2TextEncoderStep]
block_names = ["remote", "local"]
block_trigger_inputs = ["remote_text_encoder", None]
@property
def description(self):
return (
"Text encoder step that generates text embeddings to guide the image generation.\n"
"This is an auto pipeline block that selects between local and remote text encoding.\n"
" - `Flux2RemoteTextEncoderStep` is used when `remote_text_encoder=True`.\n"
" - `Flux2TextEncoderStep` is used otherwise (default)."
)
Flux2BeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
class Flux2AutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "flux2"
block_classes = [Flux2BeforeDenoiseStep]
block_names = ["before_denoise"]
block_trigger_inputs = [None]
@property
def description(self):
return (
"Before denoise step that prepares the inputs for the denoise step.\n"
"This is an auto pipeline block for Flux2.\n"
" - `Flux2BeforeDenoiseStep` is used for both text-to-image and image-conditioned generation."
)
class Flux2AutoDenoiseStep(AutoPipelineBlocks):
block_classes = [Flux2DenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents. "
"This is an auto pipeline block that works for Flux2 text-to-image and image-conditioned tasks."
" - `Flux2DenoiseStep` (denoise) for text-to-image and image-conditioned tasks."
)
class Flux2AutoDecodeStep(AutoPipelineBlocks):
block_classes = [Flux2DecodeStep]
block_names = ["decode"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`"
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2AutoTextEncoderStep()),
("text_input", Flux2AutoTextInputStep()),
("image_encoder", Flux2AutoVaeEncoderStep()),
("image_input", Flux2AutoImageInputStep()),
("before_denoise", Flux2AutoBeforeDenoiseStep()),
("denoise", Flux2AutoDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
class Flux2AutoBlocks(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
"- For text-to-image generation, all you need to provide is `prompt`.\n"
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("preprocess_images", Flux2ProcessImagesInputStep()),
("vae_encoder", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("image_input", Flux2ImageInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
"auto": AUTO_BLOCKS,
}
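
A hedged end-to-end sketch of how these blocks would be used, modeled on the test file added in this PR; the checkpoint repo id is a placeholder, and init_pipeline / load_components / output="images" follow the existing modular-pipelines API:

import torch
from diffusers.modular_pipelines import Flux2AutoBlocks

blocks = Flux2AutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-dev")   # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

images = pipe(
    prompt="A painting of a squirrel eating a burger",
    num_inference_steps=28,
    guidance_scale=4.0,
    output="images",
)
images[0].save("flux2.png")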

View File

@@ -0,0 +1,57 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2AutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
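
Hedged arithmetic for the fallback defaults above, before any components are loaded:

default_sample_size = 128
vae_scale_factor = 8                                       # fallback when self.vae is None
default_height = default_sample_size * vae_scale_factor   # 1024
default_width = default_sample_size * vae_scale_factor    # 1024
num_channels_latents = 32                                  # fallback when self.transformer is None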

View File

@@ -58,6 +58,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
@@ -1585,7 +1586,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property

View File

@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import tempfile
import unittest
import numpy as np
import PIL
import torch
from diffusers.modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
ModularPipeline,
)
from diffusers.modular_pipelines.flux2 import (
Flux2AutoTextEncoderStep,
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_save_from_pretrained(self):
pipes = []
base_pipe = self.get_pipeline().to(torch_device)
pipes.append(base_pipe)
with tempfile.TemporaryDirectory() as tmpdirname:
base_pipe.save_pretrained(tmpdirname)
pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
pipe.load_components(torch_dtype=torch.float32)
pipe.to(torch_device)
pipes.append(pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2AutoTextEncoderStep(unittest.TestCase):
def test_auto_text_encoder_block_classes(self):
auto_step = Flux2AutoTextEncoderStep()
assert len(auto_step.block_classes) == 2
assert Flux2RemoteTextEncoderStep in auto_step.block_classes
assert Flux2TextEncoderStep in auto_step.block_classes
def test_auto_text_encoder_trigger_inputs(self):
auto_step = Flux2AutoTextEncoderStep()
assert auto_step.block_trigger_inputs == ["remote_text_encoder", None]
def test_auto_text_encoder_block_names(self):
auto_step = Flux2AutoTextEncoderStep()
assert auto_step.block_names == ["remote", "local"]