Compare commits

...

4 Commits

Author               SHA1        Message                             Date
YiYi Xu              3c7494a651  Apply suggestions from code review  2026-01-20 08:09:03 -10:00
                                 Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
                                 Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
yiyi@huggingface.co  9357d8f4f7  copies                              2026-01-20 01:32:08 +00:00
yiyi@huggingface.co  fb2cb18f73  style                               2026-01-20 01:31:41 +00:00
yiyi@huggingface.co  618a8a9897  support klein                       2026-01-20 01:20:14 +00:00
12 changed files with 755 additions and 42 deletions

View File

@@ -413,6 +413,9 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -1146,6 +1149,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
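With these exports wired up, the new Klein symbols become importable from the package root when the torch and transformers backends are available (a minimal sketch; only names added in this diff are shown):

# Sketch: importing the new public symbols added by this PR.
# Without torch + transformers, the dummy objects further below are loaded instead.
from diffusers import (
    Flux2KleinAutoBlocks,
    Flux2KleinBaseAutoBlocks,
    Flux2KleinModularPipeline,
)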

View File

@@ -54,7 +54,10 @@ else:
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -81,7 +84,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .flux2 import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -43,7 +43,7 @@ else:
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
_import_structure["modular_blocks_flux2"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
@@ -54,7 +54,8 @@ else:
"Flux2BeforeDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -85,7 +86,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
from .modular_blocks_flux2 import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
@@ -96,7 +97,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2BeforeDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -353,7 +353,7 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
InputParam(name="negative_prompt_embeds", required=False),
]
@property
@@ -366,10 +366,10 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
name="negative_txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
description="4D position IDs (T, H, W, L) for negative text tokens, used for RoPE calculation.",
),
]
@@ -399,6 +399,11 @@ class Flux2RoPEInputsStep(ModularPipelineBlocks):
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
block_state.negative_txt_ids = None
if block_state.negative_prompt_embeds is not None:
block_state.negative_txt_ids = self._prepare_text_ids(block_state.negative_prompt_embeds)
block_state.negative_txt_ids = block_state.negative_txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
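This hunk mirrors the positive-prompt path for negative prompts: `negative_txt_ids` is only produced when `negative_prompt_embeds` is present, so the distilled (guidance-free) flow pays no extra cost. A minimal sketch of the contract this step now satisfies (the helper is hypothetical; the real `_prepare_text_ids` implementation is outside this hunk):

# Sketch only: prepare_text_ids is assumed to derive (B, seq_len, 4) position ids
# from an embedding tensor, as the real _prepare_text_ids does for txt_ids.
def prepare_rope_ids(prompt_embeds, negative_prompt_embeds, prepare_text_ids, device):
    txt_ids = prepare_text_ids(prompt_embeds).to(device)
    negative_txt_ids = None
    if negative_prompt_embeds is not None:
        negative_txt_ids = prepare_text_ids(negative_prompt_embeds).to(device)
    return txt_ids, negative_txt_ids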

View File

@@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Tuple
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
@@ -25,8 +28,8 @@ from ..modular_pipeline import (
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
if is_torch_xla_available():
@@ -134,6 +137,241 @@ class Flux2LoopDenoiser(ModularPipelineBlocks):
return components, block_state
# same as Flux2LoopDenoiser but guidance=None
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
# support CFG for Flux2-Klein base model
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("transformer", Flux2Transformer2DModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 4.0}),
default_creation_method="from_config",
),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Qwen3",
),
InputParam(
"negative_prompt_embeds",
required=False,
type_hint=torch.Tensor,
description="Negative text embeddings from Qwen3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"negative_txt_ids",
required=False,
type_hint=torch.Tensor,
description="4D position IDs for negative text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
InputParam(
kwargs_type="denoiser_input_fields",
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
guider_inputs = {
"encoder_hidden_states": (
getattr(block_state, "prompt_embeds", None),
getattr(block_state, "negative_prompt_embeds", None),
),
"txt_ids": (
getattr(block_state, "txt_ids", None),
getattr(block_state, "negative_txt_ids", None),
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
additional_cond_kwargs = {}
for field_name, field_value in block_state.denoiser_input_fields.items():
if field_name in transformer_args and field_name not in guider_inputs:
additional_cond_kwargs[field_name] = field_value
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
cond_kwargs.update(additional_cond_kwargs)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
components.guider.cleanup_models(components.transformer)
# perform guidance
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@@ -220,6 +458,8 @@ class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
block_state.additional_cond_kwargs = {}
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
@@ -250,3 +490,35 @@ class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2KleinBaseLoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)
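The base-model path relies on the `guider` component to merge the conditional and unconditional batches it prepared via `guider_inputs`. For orientation, a minimal sketch of the standard classifier-free-guidance combination that `ClassifierFreeGuidance` is assumed to implement (this is not the actual guider code):

import torch

def cfg_combine(noise_cond: torch.Tensor, noise_uncond: torch.Tensor, guidance_scale: float = 4.0) -> torch.Tensor:
    # Push the conditional prediction away from the unconditional one;
    # guidance_scale=4.0 matches the ComponentSpec default above.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)

The distilled variant (`Flux2KleinLoopDenoiser`) skips this entirely: it runs a single forward pass with `guidance=None` and no guider component.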

View File

@@ -15,13 +15,13 @@
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from transformers import AutoProcessor, Mistral3ForConditionalGeneration, Qwen2TokenizerFast, Qwen3ForCausalLM
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,10 +79,8 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
@@ -99,14 +97,7 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
@@ -165,10 +156,6 @@ class Flux2TextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -205,7 +192,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
@@ -222,15 +208,8 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
@@ -244,10 +223,6 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
@@ -270,6 +245,153 @@ class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
return components, state
class Flux2KleinTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2-klein"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Qwen3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Qwen3ForCausalLM),
ComponentSpec("tokenizer", Qwen2TokenizerFast),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [
ConfigSpec(name="is_distilled", default=False),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(9, 18, 27), required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from qwen3 used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
description="Negative text embeddings from qwen3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
hidden_states_layers: List[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2KleinModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
device = components._execution_device
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
if components.requires_unconditional_embeds:
negative_prompt = ""
block_state.negative_prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=negative_prompt,
device=device,
max_sequence_length=block_state.max_sequence_length,
hidden_states_layers=block_state.text_encoder_out_layers,
)
else:
block_state.negative_prompt_embeds = None
self.set_block_state(state, block_state)
return components, state
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"

View File

@@ -47,7 +47,14 @@ class Flux2TextInputStep(ModularPipelineBlocks):
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
),
InputParam(
"negative_prompt_embeds",
required=False,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
),
]
@@ -70,6 +77,12 @@ class Flux2TextInputStep(ModularPipelineBlocks):
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Negative text embeddings used to guide the image generation",
),
]
@torch.no_grad()
@@ -85,6 +98,15 @@ class Flux2TextInputStep(ModularPipelineBlocks):
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
1, block_state.num_images_per_prompt, 1
)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
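The new negative branch duplicates embeddings per generated image exactly like the positive branch. A shape sketch with hypothetical sizes:

import torch

# batch_size=2, num_images_per_prompt=3, seq_len=512, dim=4096 (all hypothetical)
embeds = torch.randn(2, 512, 4096)
embeds = embeds.repeat(1, 3, 1)       # (2, 1536, 4096): sequences tiled 3x
embeds = embeds.view(2 * 3, 512, -1)  # (6, 512, 4096): one row per output image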

View File

@@ -0,0 +1,171 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
]
)
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein.\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
)
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
)
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
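A minimal end-to-end sketch of how these auto blocks would be used, following the existing modular-pipeline workflow; the repo id, component-loading call, and output handling are assumptions based on that workflow, not confirmed by this diff:

import torch
from diffusers import Flux2KleinAutoBlocks

blocks = Flux2KleinAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-klein")  # hypothetical repo id
pipe.load_components(torch_dtype=torch.bfloat16)               # assumed loading call
pipe.to("cuda")
image = pipe(prompt="a photo of a forest at dawn", output="images")[0]  # assumed call signature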

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -55,3 +57,56 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2-Klein.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
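`get_default_blocks_name` and `requires_unconditional_embeds` together gate the two variants off the `is_distilled` config flag: distilled checkpoints run guidance-free, base checkpoints get the CFG guider and negative embeddings. A standalone sketch of the selection logic (illustration only, not the pipeline code):

def pick_blocks(config: dict) -> str:
    # Mirrors Flux2KleinModularPipeline.get_default_blocks_name above.
    if config.get("is_distilled"):
        return "Flux2KleinAutoBlocks"
    return "Flux2KleinBaseAutoBlocks"

assert pick_blocks({"is_distilled": True}) == "Flux2KleinAutoBlocks"
assert pick_blocks({"is_distilled": False}) == "Flux2KleinBaseAutoBlocks"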

View File

@@ -59,6 +59,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

View File

@@ -17,6 +17,51 @@ class Flux2AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]