Compare commits

...

15 Commits

Author SHA1 Message Date
yiyixuxu
a81893c407 add tests 2026-01-21 13:55:38 +01:00
yiyixuxu
eb221d5bc1 up up 2026-01-21 12:43:37 +01:00
yiyixuxu
a232cd9d30 up 2026-01-21 12:29:12 +01:00
yiyixuxu
1c500c8eeb flux2-dev work in modular setting 2026-01-21 08:06:32 +01:00
yiyixuxu
f49c68cecf style 2026-01-21 01:01:56 +01:00
yiyixuxu
5c1fc4489f move guidance to its own block 2026-01-21 00:59:56 +01:00
yiyixuxu
e13377e841 style 2026-01-20 23:14:59 +01:00
yiyixuxu
dea47aa24c Merge branch 'modular-klein' of github.com:huggingface/diffusers into modular-klein 2026-01-20 23:14:05 +01:00
yiyixuxu
c10041e57e a few fix: unpack latents before decoder etc 2026-01-20 23:13:53 +01:00
YiYi Xu
e1e162918d Merge branch 'main' into modular-klein 2026-01-20 10:37:49 -10:00
YiYi Xu
d2953678ad Update src/diffusers/modular_pipelines/flux2/encoders.py 2026-01-20 09:25:29 -10:00
YiYi Xu
3c7494a651 Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2026-01-20 08:09:03 -10:00
yiyi@huggingface.co
9357d8f4f7 copies 2026-01-20 01:32:08 +00:00
yiyi@huggingface.co
fb2cb18f73 style 2026-01-20 01:31:41 +00:00
yiyi@huggingface.co
618a8a9897 support klein 2026-01-20 01:20:14 +00:00
15 changed files with 1412 additions and 115 deletions
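
For orientation, the changes below add up to roughly the following usage sketch. The repo id is a placeholder and the exact loading calls may differ; `init_pipeline`, `load_components`, and the `output="images"` call follow the existing modular-pipeline pattern rather than anything confirmed in this diff.

import torch
from diffusers.modular_pipelines import Flux2KleinAutoBlocks

# Assemble the new Klein auto blocks into a runnable pipeline.
blocks = Flux2KleinAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-klein")  # hypothetical repo id
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="A painting of a squirrel eating a burger", output="images")[0]
image.save("flux2_klein.png")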

View File

@@ -413,6 +413,9 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,

View File

@@ -54,7 +54,10 @@ else:
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -43,7 +43,7 @@ else:
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
_import_structure["modular_blocks_flux2"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
@@ -51,10 +51,11 @@ else:
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2CoreDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
from .modular_blocks_flux2 import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
@@ -93,10 +94,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2CoreDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -129,17 +129,9 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
@@ -151,13 +143,12 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
device = components._execution_device
scheduler = components.scheduler
@@ -183,7 +174,7 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
@@ -191,11 +182,6 @@ class Flux2SetTimestepsStep(ModularPipelineBlocks):
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
@@ -353,7 +339,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
]
@property
@@ -365,12 +350,6 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@staticmethod
@@ -403,6 +382,72 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
return components, state
class Flux2KleinBaseRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2-Klein base model denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="negative_prompt_embeds", required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="negative_txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
block_state.negative_txt_ids = None
if block_state.negative_prompt_embeds is not None:
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
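
As a standalone illustration of what `_prepare_text_ids` returns (a sketch mirroring the logic above for the default `t_coord=None` case):

import torch

def prepare_text_ids(x: torch.Tensor) -> torch.Tensor:
    # mirrors Flux2KleinBaseRoPEInputsStep._prepare_text_ids: every text token
    # gets a (T=0, H=0, W=0, L=position) coordinate
    B, L, _ = x.shape
    coords = torch.cartesian_prod(torch.arange(1), torch.arange(1), torch.arange(1), torch.arange(L))
    return coords.unsqueeze(0).repeat(B, 1, 1)

ids = prepare_text_ids(torch.randn(2, 5, 8))
print(ids.shape)  # torch.Size([2, 5, 4])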
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@@ -506,3 +551,42 @@ class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareGuidanceStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the guidance scale tensor for Flux2 inference"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("guidance_scale", default=4.0),
InputParam("num_images_per_prompt", default=1),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
self.set_block_state(state, block_state)
return components, state
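
The guidance preparation boils down to broadcasting a scalar across the final batch; a standalone sketch of the tensor the step emits:

import torch

guidance_scale, batch_size, num_images_per_prompt = 4.0, 2, 2
guidance = torch.full([1], guidance_scale, dtype=torch.float32)
guidance = guidance.expand(batch_size * num_images_per_prompt)
print(guidance)  # tensor([4., 4., 4., 4.]) -- one scale entry per final batch element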

View File

@@ -29,29 +29,16 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
class Flux2UnpackLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
return "Step that unpacks the latents from the denoising step"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
@@ -70,9 +57,9 @@ class Flux2DecodeStep(ModularPipelineBlocks):
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
"latents",
type_hint=torch.Tensor,
description="The denoise latents from denoising step, unpacked with position IDs.",
)
]
@@ -107,6 +94,62 @@ class Flux2DecodeStep(ModularPipelineBlocks):
return torch.stack(x_list, dim=0)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
block_state.latents = latents
self.set_block_state(state, block_state)
return components, state
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
@@ -121,26 +164,20 @@ class Flux2DecodeStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = block_state.latents
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state
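
The batch-norm denormalization in the decode step, sketched standalone with dummy statistics (the real values come from `vae.bn` and `vae.config.batch_norm_eps`):

import torch

latents = torch.randn(1, 32, 8, 8)  # unpacked, unpatchified latents, (B, C, H, W)
running_mean = torch.zeros(32)      # stand-in for vae.bn.running_mean
running_var = torch.ones(32)        # stand-in for vae.bn.running_var
eps = 1e-5                          # stand-in for vae.config.batch_norm_eps

std = torch.sqrt(running_var.view(1, -1, 1, 1) + eps)
mean = running_mean.view(1, -1, 1, 1)
latents = latents * std + mean      # undo the VAE's batch-norm normalization before decoding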

View File

@@ -16,6 +16,8 @@ from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
@@ -25,8 +27,8 @@ from ..modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
if is_torch_xla_available():
@@ -134,6 +136,229 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
return components, block_state
# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
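
The trailing slice drops the image-conditioning tokens that were concatenated onto the latent sequence; a standalone sketch:

import torch

latents = torch.randn(1, 16, 64)        # tokens being denoised
image_latents = torch.randn(1, 8, 64)   # conditioning tokens appended along the sequence dim

model_input = torch.cat([latents, image_latents], dim=1)  # (1, 24, 64)
noise_pred = torch.randn_like(model_input)                # stand-in for the transformer output
noise_pred = noise_pred[:, : latents.size(1)]             # keep predictions for the 16 denoised tokens only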
# support CFG for Flux2-Klein base model
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("transformer", Flux2Transformer2DModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"negative_prompt_embeds",
required=False,
type_hint=torch.Tensor,
description="Negative text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"negative_txt_ids",
required=False,
type_hint=torch.Tensor,
description="4D position IDs for negative text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"txt_ids": (
getattr(block_state, "txt_ids", None),
getattr(block_state, "negative_txt_ids", None),
),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
components.guider.cleanup_models(components.transformer)
# perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
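
The final `components.guider(guider_state)` call reduces, for plain classifier-free guidance, to the standard combination below (a sketch; the `ClassifierFreeGuidance` class additionally handles batching, state, and optional rescaling):

import torch

guidance_scale = 4.0
noise_cond = torch.randn(1, 16, 64)    # prediction conditioned on prompt_embeds
noise_uncond = torch.randn(1, 16, 64)  # prediction from negative_prompt_embeds

# push the conditional prediction away from the unconditional one
noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)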
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@@ -250,3 +475,35 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinBaseLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)

View File

@@ -15,13 +15,15 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,10 +81,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
@@ -99,14 +99,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -165,10 +158,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -205,7 +194,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
@@ -222,15 +210,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -244,10 +225,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -270,6 +247,289 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
return components, state
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=True),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
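
The layer stacking at the end of `_get_qwen3_prompt_embeds` concatenates the selected hidden states per token; a shape-only sketch with stand-in tensors:

import torch

batch_size, seq_len, hidden_dim = 2, 8, 16
# stand-ins for output.hidden_states at layers (9, 18, 27)
layers = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(3)]

out = torch.stack(layers, dim=1)  # (B, 3, L, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, 3 * hidden_dim)
print(prompt_embeds.shape)  # torch.Size([2, 8, 48]) -- the three layers concatenated per token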
class Flux2KleinBaseTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Negative text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
if components.requires_unconditional_embeds:
negative_prompt = [""] * len(prompt)
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -47,7 +47,7 @@ class Flux2TextInputStep(ModularPipelineBlocks):
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
]
@@ -89,6 +89,90 @@ class Flux2TextInputStep(ModularPipelineBlocks):
return components, state
class Flux2KleinBaseTextInputStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
required=False,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Negative text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
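
The repeat/view pattern above is the usual diffusers idiom for expanding per-prompt embeddings to the final batch; a standalone sketch:

import torch

batch_size, seq_len, dim, num_images_per_prompt = 2, 8, 16, 3
prompt_embeds = torch.randn(batch_size, seq_len, dim)

prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
print(prompt_embeds.shape)  # torch.Size([6, 8, 16]); the copies of each prompt stay adjacent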
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2PrepareGuidanceStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
@@ -41,7 +47,6 @@ Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
@@ -72,33 +77,56 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
)
Flux2BeforeDenoiseBlocks = InsertableDict(
Flux2CoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
return (
"Core denoise step that performs the denoising process for Flux2-dev.\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -107,10 +135,8 @@ AUTO_BLOCKS = InsertableDict(
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -130,6 +156,16 @@ class Flux2AutoBlocks(SequentialPipelineBlocks):
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
TEXT2IMAGE_BLOCKS = InsertableDict(
[
@@ -137,8 +173,10 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
@@ -152,8 +190,10 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)

View File

@@ -0,0 +1,232 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2KleinBaseRoPEInputsStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinBaseTextEncoderStep,
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2KleinBaseTextInputStep,
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
################
# VAE encoder
################
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
###
### Core denoise
###
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return (
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
###
### Auto blocks
###
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinBaseTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
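
A quick way to inspect the composed auto blocks (a sketch; requires torch and transformers installed, and uses the exports this PR adds to `diffusers.modular_pipelines`):

from diffusers.modular_pipelines import Flux2KleinBaseAutoBlocks

blocks = Flux2KleinBaseAutoBlocks()
print(blocks.block_names)  # ['text_encoder', 'vae_encoder', 'denoise', 'decode']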

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -55,3 +57,56 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
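
`get_default_blocks_name` routes a checkpoint to the right block set from its `is_distilled` config; the same decision as a standalone sketch:

def pick_blocks(config_dict):
    # mirrors Flux2KleinModularPipeline.get_default_blocks_name: distilled
    # (guidance-free) checkpoints use Flux2KleinAutoBlocks, everything else
    # the CFG-capable Flux2KleinBaseAutoBlocks
    if config_dict is not None and config_dict.get("is_distilled"):
        return "Flux2KleinAutoBlocks"
    return "Flux2KleinBaseAutoBlocks"

assert pick_blocks({"is_distilled": True}) == "Flux2KleinAutoBlocks"
assert pick_blocks({"is_distilled": False}) == "Flux2KleinBaseAutoBlocks"
assert pick_blocks(None) == "Flux2KleinBaseAutoBlocks"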

View File

@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

View File

@@ -17,6 +17,51 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2KleinAutoBlocks,
Flux2KleinModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2KleinModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2KleinImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-modular"
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return

View File

@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2KleinBaseModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
params = frozenset(["prompt", "height", "width"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2KleinBaseImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2KleinModularPipeline
pipeline_blocks_class = Flux2KleinBaseAutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein-base-modular"
params = frozenset(["prompt", "height", "width", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return