Compare commits

..

4 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| Sayak Paul | 9bd83616bf | Merge branch 'main' into enable-cp-kernels | 2025-12-10 12:33:18 +08:00 |
| sayakpaul | f732ff1144 | up | 2025-12-09 15:30:33 +05:30 |
| sayakpaul | 7a8f85b047 | up | 2025-12-09 14:59:01 +05:30 |
| sayakpaul | 82d20e64a5 | up | 2025-12-09 14:39:07 +05:30 |
22 changed files with 429 additions and 2065 deletions

View File

@@ -237,8 +237,6 @@ By selectively loading and unloading the models you need at a given stage and sh
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
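For example (a minimal sketch, assuming `pipeline` is an already-loaded pipeline; the backend names are the ones used elsewhere on this page):
```py
# Switch the transformer's attention backend; names follow the table linked above.
pipeline.transformer.set_attention_backend("flash")            # flash-attention backend
# pipeline.transformer.set_attention_backend("_native_cudnn")  # or PyTorch's cuDNN kernel
```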
### Ring Attention
Key (K) and value (V) representations are exchanged between devices using [Ring Attention](https://huggingface.co/papers/2310.01889), so each sequence shard eventually sees every other token's K/V. Each GPU computes attention with the K/V chunk it currently holds and passes that chunk to the next GPU in the ring. No single GPU holds the full sequence, which keeps memory and communication overhead low.
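The merge rule behind this is easiest to see on a single device. The sketch below is not diffusers code (names and shapes are illustrative): it processes K/V one chunk at a time and combines the partial outputs with their log-sum-exp (LSE) statistics, which is the same math each rank applies as K/V chunks travel around the ring.
```py
import torch

def ring_attention_reference(q, k, v, num_chunks=4):
    # q, k, v: (batch, heads, seq, dim); process K/V chunk by chunk and merge with LSE
    scale = q.shape[-1] ** -0.5
    outs, lses = [], []
    for k_chunk, v_chunk in zip(k.chunk(num_chunks, dim=2), v.chunk(num_chunks, dim=2)):
        scores = (q @ k_chunk.transpose(-2, -1)) * scale        # (b, h, s_q, chunk)
        lse = torch.logsumexp(scores, dim=-1, keepdim=True)     # per-chunk LSE
        outs.append(torch.exp(scores - lse) @ v_chunk)          # per-chunk softmax @ V
        lses.append(lse)
    lse_all = torch.logsumexp(torch.cat(lses, dim=-1), dim=-1, keepdim=True)
    weights = [torch.exp(lse - lse_all) for lse in lses]        # renormalize each chunk
    return sum(w * o for w, o in zip(weights, outs))

q, k, v = (torch.randn(1, 2, 16, 8) for _ in range(3))
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(ring_attention_reference(q, k, v), ref, atol=1e-5)
```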
@@ -247,60 +245,40 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
```py
import torch
from torch import distributed as dist
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

try:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    transformer = AutoModel.from_pretrained(
        "Qwen/Qwen-Image",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        parallel_config=ContextParallelConfig(ring_degree=2),
    )
    pipeline = QwenImagePipeline.from_pretrained(
        "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda"
    )
    pipeline.transformer.set_attention_backend("flash")

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify a seeded generator so all ranks start from the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

    if dist.get_rank() == 0:
        image.save("output.png")
except Exception as e:
    print(f"An error occurred: {e}")
    dist.breakpoint()
    raise
finally:
    if dist.is_initialized():
        dist.destroy_process_group()
```
The script above must be launched with a PyTorch-compatible distributed launcher such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.
```shell
torchrun --nproc-per-node 2 above_script.py
```
### Ulysses Attention
[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends data to and receives data from every other device) so that each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally over the full sequence for its heads, then a second all-to-all regroups the results by tokens for the next layer.
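The layout change is easy to check on a single device. The sketch below is not diffusers code (shapes and helper names are illustrative): it starts from sequence-sharded Q/K/V, regroups the shards so every simulated rank holds the full sequence for a subset of heads, runs attention locally, and verifies the result matches full attention.
```py
import torch
import torch.nn.functional as F

world_size, batch, heads, seq, dim = 2, 1, 4, 16, 8
q, k, v = (torch.randn(batch, heads, seq, dim) for _ in range(3))

def all_to_all_seq_to_heads(shards):
    # shards: list of per-rank tensors (batch, heads, seq/world, dim)
    # returns: list of per-rank tensors (batch, heads/world, seq, dim)
    out = []
    for rank in range(world_size):
        head_slice = slice(rank * heads // world_size, (rank + 1) * heads // world_size)
        out.append(torch.cat([shard[:, head_slice] for shard in shards], dim=2))
    return out

q_shards, k_shards, v_shards = (list(t.chunk(world_size, dim=2)) for t in (q, k, v))
q_h, k_h, v_h = (all_to_all_seq_to_heads(s) for s in (q_shards, k_shards, v_shards))

# each simulated "rank" attends over the full sequence for its head subset
local_outs = [F.scaled_dot_product_attention(qi, ki, vi) for qi, ki, vi in zip(q_h, k_h, v_h)]
out = torch.cat(local_outs, dim=1)  # regroup heads (the second all-to-all, collapsed here)

assert torch.allclose(out, F.scaled_dot_product_attention(q, k, v), atol=1e-5)
```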
@@ -310,26 +288,5 @@ The script above needs to be run with a distributed launcher, such as [torchrun]
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
```py
# Set `ulysses_degree` to the number of available GPUs.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.
```py
import torch
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

CKPT_ID = "black-forest-labs/FLUX.1-dev"
cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)
pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)  # `device` is the per-rank CUDA device from the distributed setup shown earlier
```

View File

@@ -404,8 +404,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -1113,8 +1111,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,

View File

@@ -256,6 +256,10 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -270,7 +274,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -594,22 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -969,6 +994,231 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1015,6 +1265,46 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
# ===== Context parallel =====
@@ -1372,7 +1662,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1386,17 +1676,35 @@ def _flash_attention_hub(
) -> torch.Tensor:
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -1539,7 +1847,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -1553,31 +1861,65 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2107,7 +2449,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2132,6 +2474,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -47,7 +47,6 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
logger = logging.get_logger(__name__)
@@ -430,12 +429,8 @@ def _load_shard_files_with_threadpool(
low_cpu_mem_usage=low_cpu_mem_usage,
)
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
if not is_torch_dist_rank_zero():
tqdm_kwargs["disable"] = True
with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(**tqdm_kwargs) as pbar:
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
for future in as_completed(futures):
result = future.result()

View File

@@ -59,8 +59,11 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
from ..utils.hub_utils import (
PushToHubMixin,
load_or_create_model_card,
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
@@ -1669,10 +1672,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
shard_files = resolved_model_file
if len(resolved_model_file) > 1:
shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
if not is_torch_dist_rank_zero():
shard_tqdm_kwargs["disable"] = True
shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
for shard_file in shard_files:
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)

View File

@@ -52,10 +52,6 @@ else:
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2ModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
@@ -79,7 +75,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -1,111 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = [
"Flux2TextEncoderStep",
"Flux2RemoteTextEncoderStep",
"Flux2VaeEncoderStep",
]
_import_structure["before_denoise"] = [
"Flux2SetTimestepsStep",
"Flux2PrepareLatentsStep",
"Flux2RoPEInputsStep",
"Flux2PrepareImageLatentsStep",
]
_import_structure["denoise"] = [
"Flux2LoopDenoiser",
"Flux2LoopAfterDenoiser",
"Flux2DenoiseLoopWrapper",
"Flux2DenoiseStep",
]
_import_structure["decoders"] = ["Flux2DecodeStep"]
_import_structure["inputs"] = [
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import (
Flux2DenoiseLoopWrapper,
Flux2DenoiseStep,
Flux2LoopAfterDenoiser,
Flux2LoopDenoiser,
)
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
REMOTE_AUTO_BLOCKS,
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,508 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import numpy as np
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
"""Compute empirical mu for Flux2 timestep scheduling."""
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class Flux2SetTimestepsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
height = block_state.height or components.default_height
width = block_state.width or components.default_width
vae_scale_factor = components.vae_scale_factor
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
timesteps = block_state.timesteps
if timesteps is None and sigmas is None:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
]
@staticmethod
def check_inputs(components, block_state):
vae_scale_factor = components.vae_scale_factor
if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
)
@staticmethod
def _prepare_latent_ids(latents: torch.Tensor):
"""
Generates 4D position coordinates (T, H, W, L) for latent tensors.
Args:
latents: Latent tensor of shape (B, C, H, W)
Returns:
Position IDs tensor of shape (B, H*W, 4)
"""
batch_size, _, height, width = latents.shape
t = torch.arange(1)
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1)
latent_ids = torch.cartesian_prod(t, h, w, l)
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
return latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@staticmethod
def prepare_latents(
comp,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
latents = self.prepare_latents(
components,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
latent_ids = self._prepare_latent_ids(latents)
latent_ids = latent_ids.to(block_state.device)
latents = self._pack_latents(latents)
block_state.latents = latents
block_state.latent_ids = latent_ids
self.set_block_state(state, block_state)
return components, state
class Flux2RoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares image latents and their position IDs for Flux2 image conditioning."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image_latents", type_hint=List[torch.Tensor]),
InputParam("batch_size", required=True, type_hint=int),
InputParam("num_images_per_prompt", default=1, type_hint=int),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning",
),
OutputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents",
),
]
@staticmethod
def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
"""
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
Args:
image_latents: A list of image latent feature tensors of shape (1, C, H, W).
scale: Factor used to define the time separation between latents.
Returns:
Combined coordinate tensor of shape (1, N_total, 4)
"""
if not isinstance(image_latents, list):
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
image_latents = block_state.image_latents
if image_latents is None:
block_state.image_latents = None
block_state.image_latent_ids = None
self.set_block_state(state, block_state)
return components, state
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
image_latent_ids = self._prepare_image_ids(image_latents)
packed_latents = []
for latent in image_latents:
packed = self._pack_latents(latent)
packed = packed.squeeze(0)
packed_latents.append(packed)
image_latents = torch.cat(packed_latents, dim=0)
image_latents = image_latents.unsqueeze(0)
image_latents = image_latents.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.to(device)
block_state.image_latents = image_latents
block_state.image_latent_ids = image_latent_ids
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,146 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="Position IDs for the latents, used for unpacking",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
"""
Unpack latents using position IDs to scatter tokens into place.
Args:
x: Packed latents tensor of shape (B, seq_len, C)
x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates
Returns:
Unpacked latents tensor of shape (B, C, H, W)
"""
x_list = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
out = out.view(h, w, ch).permute(2, 0, 1)
x_list.append(out)
return torch.stack(x_list, dim=0)
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
latents = latents.permute(0, 1, 4, 2, 5, 3)
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
return latents
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,252 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2LoopDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Mistral3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that updates the latents after denoising. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoises the latents over `timesteps`. "
"The specific steps within each iteration can be customized with `sub_blocks` attribute"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self.set_block_state(state, block_state)
return components, state
class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2LoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)

View File

@@ -1,349 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def format_text_input(prompts: List[str], system_message: str = None):
"""Format prompts for Mistral3 chat template."""
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2TextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
# fmt: off
DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
# fmt: on
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
ComponentSpec("tokenizer", AutoProcessor),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from Mistral3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
def _get_mistral_3_prompt_embeds(
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
# fmt: off
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
# fmt: on
hidden_states_layers: Tuple[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
inputs = tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=block_state.device,
max_sequence_length=block_state.max_sequence_length,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
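
The layer-stacking above is easiest to follow with concrete shapes: hidden states from the selected encoder layers are stacked, then folded into the feature dimension so each token carries features from every selected layer. A minimal sketch on dummy tensors (the sizes and layer indices are illustrative, not the real Mistral3 configuration):

```py
import torch

# Pretend hidden states from a text encoder: one tensor per layer, [batch, seq_len, hidden_dim].
batch, seq_len, hidden_dim = 2, 512, 64  # illustrative sizes only
hidden_states = [torch.randn(batch, seq_len, hidden_dim) for _ in range(31)]

layers = (10, 20, 30)
out = torch.stack([hidden_states[k] for k in layers], dim=1)  # [B, num_layers, S, H]
# Move the layer axis next to the hidden axis and merge them per token.
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch, seq_len, len(layers) * hidden_dim)

print(prompt_embeds.shape)  # torch.Size([2, 512, 192])
```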
class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using a remote API endpoint"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from remote API used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
import io
import requests
from huggingface_hub import get_token
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
response = requests.post(
self.REMOTE_URL,
json={"prompt": prompt},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)
self.set_block_state(state, block_state)
return components, state
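
Stripped of the block plumbing, the remote path is a small HTTP client: POST the prompt list with a Hugging Face token, then deserialize the returned tensor bytes. A standalone sketch under the same assumptions as the block above (JSON `{"prompt": [...]}` in, a torch-serialized tensor out):

```py
import io

import requests
import torch
from huggingface_hub import get_token

REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"

def remote_prompt_embeds(prompts):
    response = requests.post(
        REMOTE_URL,
        json={"prompt": prompts},
        headers={"Authorization": f"Bearer {get_token()}", "Content-Type": "application/json"},
        timeout=60,
    )
    response.raise_for_status()
    # weights_only=True restricts torch.load to tensor data, avoiding arbitrary unpickling.
    return torch.load(io.BytesIO(response.content), weights_only=True)

# prompt_embeds = remote_prompt_embeds(["a cat sipping a margarita"])
```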
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("vae", AutoencoderKLFlux2)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("condition_images", type_hint=List[torch.Tensor]),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=List[torch.Tensor],
description="List of latent representations for each reference image",
),
]
@staticmethod
def _patchify_latents(latents):
"""Convert latents to patchified format for Flux2."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
return latents
def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
"""Encode a single image using Flux2 VAE with batch norm normalization."""
if image.ndim != 4:
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
condition_images = block_state.condition_images
if condition_images is None:
return components, state
device = components._execution_device
dtype = components.vae.dtype
image_latents = []
for image in condition_images:
image = image.to(device=device, dtype=dtype)
latent = self._encode_vae_image(
vae=components.vae,
image=image,
generator=block_state.generator,
)
image_latents.append(latent)
block_state.image_latents = image_latents
self.set_block_state(state, block_state)
return components, state
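
`_patchify_latents` is a 2×2 space-to-depth rearrangement: every 2×2 spatial patch of the VAE latents is folded into the channel dimension before the batch-norm statistics are applied. A quick shape check on a dummy tensor (sizes are illustrative):

```py
import torch

latents = torch.randn(1, 32, 64, 64)  # [B, C, H, W] as produced by the VAE encoder
b, c, h, w = latents.shape

patched = latents.view(b, c, h // 2, 2, w // 2, 2)
patched = patched.permute(0, 1, 3, 5, 2, 4)       # bring the 2x2 patch axes next to the channels
patched = patched.reshape(b, c * 4, h // 2, w // 2)

print(patched.shape)  # torch.Size([1, 128, 32, 32]) -> 4x channels, half the spatial resolution
```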

View File

@@ -1,160 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from ...configuration_utils import FrozenDict
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__)
class Flux2TextInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
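
The repeat/view pair duplicates each prompt embedding `num_images_per_prompt` times along the batch axis. A toy-shape sketch of the same expansion:

```py
import torch

prompt_embeds = torch.randn(2, 512, 192)  # [batch_size, seq_len, dim], illustrative sizes
num_images_per_prompt = 3

batch_size, seq_len, _ = prompt_embeds.shape
expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1)               # [2, 1536, 192]
expanded = expanded.view(batch_size * num_images_per_prompt, seq_len, -1)  # [6, 512, 192]

print(expanded.shape)  # torch.Size([6, 512, 192])
```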
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image"),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
images = block_state.image
if images is None:
block_state.condition_images = None
self.set_block_state(state, block_state)
return components, state
if not isinstance(images, list):
images = [images]
condition_images = []
for img in images:
components.image_processor.check_image_input(img)
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
multiple_of = components.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
condition_img = components.image_processor.preprocess(
img, height=image_height, width=image_width, resize_mode="crop"
)
condition_images.append(condition_img)
if block_state.height is None:
block_state.height = image_height
if block_state.width is None:
block_state.width = image_width
block_state.condition_images = condition_images
self.set_block_state(state, block_state)
return components, state
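
The preprocessing above caps reference images at roughly one megapixel and rounds each side down to a multiple of `vae_scale_factor * 2` so the latents patchify cleanly. A sketch of just the size arithmetic, for a hypothetical 1536×1024 input and the default scale factor of 8 (the area resize here only approximates what `_resize_to_target_area` does):

```py
image_width, image_height = 1536, 1024  # hypothetical reference image
vae_scale_factor = 8                    # pipeline default, see Flux2ModularPipeline below

# 1. Cap the area at ~1024*1024 pixels while keeping the aspect ratio.
if image_width * image_height > 1024 * 1024:
    scale = (1024 * 1024 / (image_width * image_height)) ** 0.5
    image_width, image_height = int(image_width * scale), int(image_height * scale)

# 2. Round each side down to a multiple of vae_scale_factor * 2.
multiple_of = vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of

print(image_width, image_height)  # 1248 832 for this example
```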

View File

@@ -1,166 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2VaeEncoderBlocks.values()
block_names = Flux2VaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2VaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
Flux2BeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
class Flux2AutoBlocks(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
"- For text-to-image generation, all you need to provide is `prompt`.\n"
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("preprocess_images", Flux2ProcessImagesInputStep()),
("vae_encoder", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
"auto": AUTO_BLOCKS,
"remote": REMOTE_AUTO_BLOCKS,
}
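
For context, these presets are consumed through the modular pipeline API rather than imported directly into user code. Below is a minimal sketch of wiring the explicit text-to-image preset into a runnable pipeline, following the pattern other modular pipelines in the library use; the repo id is a placeholder, and `from_blocks_dict` / `init_pipeline` / `load_components` are assumed to behave as they do for those pipelines:

```py
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

# TEXT2IMAGE_BLOCKS is the preset defined above.
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)

pipeline = blocks.init_pipeline("diffusers/flux2-modular")  # placeholder modular repo id
pipeline.load_components(torch_dtype=torch.bfloat16)
pipeline.to("cuda")

image = pipeline(prompt="a cat sipping a margarita in Palm Springs", output="images")[0]
image.save("flux2_modular.png")
```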

View File

@@ -1,57 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2AutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
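
With the defaults above, the numbers work out as follows: a sample size of 128 and an 8× VAE scale factor give a 1024×1024 default resolution, and a transformer with `in_channels = 128` would yield 32 latent channels. A quick check of that arithmetic:

```py
default_sample_size = 128
vae_scale_factor = 8  # 2 ** (len(vae.config.block_out_channels) - 1) with the default config

print(default_sample_size * vae_scale_factor)  # 1024 -> default height and width
print(128 // 4)                                # 32   -> num_channels_latents for in_channels=128
```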

View File

@@ -58,7 +58,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
@@ -1587,6 +1586,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property

View File

@@ -67,7 +67,6 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
@@ -983,11 +982,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 7. Load each module in the pipeline
current_device_map = None
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
if not is_torch_dist_rank_zero():
logging_tqdm_kwargs["disable"] = True
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
# 7.1 device_map shenanigans
if final_device_map is not None:
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1913,14 +1908,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
progress_bar_config = dict(self._progress_bar_config)
if "disable" not in progress_bar_config:
progress_bar_config["disable"] = not is_torch_dist_rank_zero()
if iterable is not None:
return tqdm(iterable, **progress_bar_config)
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **progress_bar_config)
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")

View File

@@ -1,36 +0,0 @@
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
try:
import torch
except ImportError:
torch = None
def is_torch_dist_rank_zero() -> bool:
if torch is None:
return True
dist_module = getattr(torch, "distributed", None)
if dist_module is None or not dist_module.is_available():
return True
if not dist_module.is_initialized():
return True
try:
return dist_module.get_rank() == 0
except (RuntimeError, ValueError):
return True

View File

@@ -2,36 +2,6 @@
from ..utils import DummyObject, requires_backends
class Flux2AutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -32,8 +32,6 @@ from typing import Dict, Optional
from tqdm import auto as tqdm_lib
from .distributed_utils import is_torch_dist_rank_zero
_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None
@@ -49,23 +47,6 @@ log_levels = {
_default_log_level = logging.WARNING
_tqdm_active = True
_rank_zero_filter = None
class _RankZeroFilter(logging.Filter):
def filter(self, record):
# Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
global _rank_zero_filter
if _rank_zero_filter is None:
_rank_zero_filter = _RankZeroFilter()
if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
logger.addFilter(_rank_zero_filter)
def _get_default_logging_level() -> int:
@@ -109,7 +90,6 @@ def _configure_library_root_logger() -> None:
library_root_logger.addHandler(_default_handler)
library_root_logger.setLevel(_get_default_logging_level())
library_root_logger.propagate = False
_ensure_rank_zero_filter(library_root_logger)
def _reset_library_root_logger() -> None:
@@ -140,9 +120,7 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
name = _get_library_name()
_configure_library_root_logger()
logger = logging.getLogger(name)
_ensure_rank_zero_filter(logger)
return logger
return logging.getLogger(name)
def get_verbosity() -> int:

View File

@@ -1,93 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return

View File

@@ -165,6 +165,7 @@ class ModularPipelineTesterMixin:
expected_max_diff=1e-4,
):
pipe = self.get_pipeline().to(torch_device)
inputs = self.get_dummy_inputs()
# Reset the generator in case it has been used in self.get_dummy_inputs