Compare commits

...

7 Commits

Author SHA1 Message Date
Sayak Paul
764f7ede33 [core] Flux2 klein kv followups (#13264)
* implement Flux2Transformer2DModelOutput.

* add output class to docs.

* add Flux2KleinKV to docs.

* add pipeline tests for klein kv.
2026-03-13 10:05:11 +05:30
Sayak Paul
8d0f3e1ba8 [lora] fix z-image non-diffusers lora loading. (#13255)
fix z-image non-diffusers lora loading.
2026-03-13 06:58:53 +05:30
huemin
094caf398f klein 9b kv (#13262)
* klein 9b kv

* Apply style fixes

* fix typo inline modulation split

* make fix-copies

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-03-12 06:53:56 -10:00
Alvaro Bartolome
81c354d879 Add PRXPipeline in AUTO_TEXT2IMAGE_PIPELINES_MAPPING (#13257) 2026-03-11 14:39:24 -03:00
Miguel Martin
0a2c26d0a4 Update Documentation for NVIDIA Cosmos (#13251)
* fix docs

* update main example
2026-03-11 09:14:56 -07:00
Dhruv Nair
07c5ba8eee [Context Parallel] Add support for custom device mesh (#13064)
* add custom mesh support

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-11 16:42:11 +05:30
Dhruv Nair
897aed72fa [Quantization] Deprecate Quanto (#13180)
* update

* update
2026-03-11 09:26:46 +05:30
20 changed files with 1739 additions and 54 deletions

View File

@@ -532,8 +532,6 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
@@ -677,6 +675,8 @@
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios

View File

@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
## Flux2Transformer2DModel
[[autodoc]] Flux2Transformer2DModel
## Flux2Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput

View File

@@ -21,29 +21,31 @@
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
## Basic usage
```python
import torch
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
model_id = "nvidia/Cosmos-Predict2.5-2B"
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
)
pipe.to("cuda")
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
image=None,
video=None,
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=93,
generator=torch.Generator().manual_seed(1),
).frames[0]
export_to_video(output, "text2world.mp4", fps=16)
```
## Cosmos2_5_TransferPipeline

View File

@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2KleinPipeline
[[autodoc]] Flux2KleinPipeline
- all
- __call__
## Flux2KleinKVPipeline
[[autodoc]] Flux2KleinKVPipeline
- all
- __call__

View File

@@ -44,6 +44,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cosmos](cosmos) | text2video, video2video |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
| [DDPM](ddpm) | unconditional image generation |

View File

@@ -510,6 +510,7 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"Flux2KleinKVPipeline",
"Flux2KleinPipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
@@ -1266,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
Flux2KleinKVPipeline,
Flux2KleinPipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,

View File

@@ -2538,8 +2538,12 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
alpha_tensor = state_dict.pop(alpha_key, None)
if alpha_tensor is None:
return 1.0, 1.0
scale = (
alpha_tensor.item() / rank
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -68,6 +78,7 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any
import torch
@@ -21,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import BaseOutput, apply_lora_scale, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -32,7 +33,6 @@ from ..embeddings import (
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
@@ -40,6 +40,216 @@ from ..normalization import AdaLayerNormContinuous
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Flux2Transformer2DModelOutput(BaseOutput):
"""
The output of [`Flux2Transformer2DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on the `encoder_hidden_states` input.
kv_cache (`Flux2KVCache`, *optional*):
The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
"""
sample: "torch.Tensor" # noqa: F821
kv_cache: "Flux2KVCache | None" = None
class Flux2KVLayerCache:
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor
format: (batch_size, num_ref_tokens, num_heads, head_dim).
"""
def __init__(self):
self.k_ref: torch.Tensor | None = None
self.v_ref: torch.Tensor | None = None
def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
"""Store reference token K/V."""
self.k_ref = k_ref
self.v_ref = v_ref
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Retrieve cached reference token K/V."""
if self.k_ref is None:
raise RuntimeError("KV cache has not been populated yet.")
return self.k_ref, self.v_ref
def clear(self):
self.k_ref = None
self.v_ref = None
class Flux2KVCache:
"""Container for all layers' reference-token KV caches.
Holds separate cache lists for double-stream and single-stream transformer blocks.
"""
def __init__(self, num_double_layers: int, num_single_layers: int):
self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
self.num_ref_tokens: int = 0
def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
return self.double_block_caches[layer_idx]
def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
return self.single_block_caches[layer_idx]
def clear(self):
for cache in self.double_block_caches:
cache.clear()
for cache in self.single_block_caches:
cache.clear()
self.num_ref_tokens = 0
def _flux2_kv_causal_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
num_txt_tokens: int,
num_ref_tokens: int,
kv_cache: Flux2KVLayerCache | None = None,
backend=None,
) -> torch.Tensor:
"""Causal attention for KV caching where reference tokens only self-attend.
All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens
only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected
between txt and img.
"""
# No ref tokens and no cache — standard full attention
if num_ref_tokens == 0 and kv_cache is None:
return dispatch_attention_fn(query, key, value, backend=backend)
if kv_cache is not None:
# Cached mode: inject ref K/V between txt and img
k_ref, v_ref = kv_cache.get()
k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
return dispatch_attention_fn(query, k_all, v_all, backend=backend)
# Extract mode: ref tokens self-attend, txt+img attend to all
ref_start = num_txt_tokens
ref_end = num_txt_tokens + num_ref_tokens
q_txt = query[:, :ref_start]
q_ref = query[:, ref_start:ref_end]
q_img = query[:, ref_end:]
k_txt = key[:, :ref_start]
k_ref = key[:, ref_start:ref_end]
k_img = key[:, ref_end:]
v_txt = value[:, :ref_start]
v_ref = value[:, ref_start:ref_end]
v_img = value[:, ref_end:]
# txt+img attend to all tokens
q_txt_img = torch.cat([q_txt, q_img], dim=1)
k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
attn_txt = attn_txt_img[:, :ref_start]
attn_img = attn_txt_img[:, ref_start:]
# ref tokens self-attend only
attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
def _blend_mod_params(
img_params: tuple[torch.Tensor, ...],
ref_params: tuple[torch.Tensor, ...],
num_ref: int,
seq_len: int,
) -> tuple[torch.Tensor, ...]:
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
blended = []
for im, rm in zip(img_params, ref_params):
if im.ndim == 2:
im = im.unsqueeze(1)
rm = rm.unsqueeze(1)
B = im.shape[0]
blended.append(
torch.cat(
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
dim=1,
)
)
return tuple(blended)
def _blend_double_block_mods(
img_mod: torch.Tensor,
ref_mod: torch.Tensor,
num_ref: int,
seq_len: int,
) -> torch.Tensor:
"""Blend double-block image-stream modulations for a [ref, img] sequence layout.
Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible
with `Flux2Modulation.split(mod, 2)`.
"""
if img_mod.ndim == 2:
img_mod = img_mod.unsqueeze(1)
ref_mod = ref_mod.unsqueeze(1)
img_chunks = torch.chunk(img_mod, 6, dim=-1)
ref_chunks = torch.chunk(ref_mod, 6, dim=-1)
img_mods = (img_chunks[0:3], img_chunks[3:6])
ref_mods = (ref_chunks[0:3], ref_chunks[3:6])
all_params = []
for img_set, ref_set in zip(img_mods, ref_mods):
blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
all_params.extend(blended)
return torch.cat(all_params, dim=-1)
def _blend_single_block_mods(
single_mod: torch.Tensor,
ref_mod: torch.Tensor,
num_txt: int,
num_ref: int,
seq_len: int,
) -> torch.Tensor:
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
"""
if single_mod.ndim == 2:
single_mod = single_mod.unsqueeze(1)
ref_mod = ref_mod.unsqueeze(1)
img_params = torch.chunk(single_mod, 3, dim=-1)
ref_params = torch.chunk(ref_mod, 3, dim=-1)
blended = []
for im, rm in zip(img_params, ref_params):
if im.ndim == 2:
im = im.unsqueeze(1)
rm = rm.unsqueeze(1)
B = im.shape[0]
im_expanded = im.expand(B, seq_len, -1)
rm_expanded = rm.expand(B, num_ref, -1)
blended.append(
torch.cat(
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
dim=1,
)
)
return torch.cat(blended, dim=-1)
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
@@ -181,9 +391,108 @@ class Flux2AttnProcessor:
return hidden_states
class Flux2KVAttnProcessor:
"""
Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is
used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are
injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "Flux2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
kv_cache: Flux2KVLayerCache | None = None,
kv_cache_mode: str | None = None,
num_ref_tokens: int = 0,
) -> torch.Tensor:
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
attn, hidden_states, encoder_hidden_states
)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if attn.added_kv_proj_dim is not None:
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
encoder_query = attn.norm_added_q(encoder_query)
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=1)
key = torch.cat([encoder_key, key], dim=1)
value = torch.cat([encoder_value, value], dim=1)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
# Extract ref K/V from the combined sequence
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
ref_start = num_txt_tokens
ref_end = num_txt_tokens + num_ref_tokens
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
# Dispatch attention
if kv_cache_mode == "extract" and num_ref_tokens > 0:
hidden_states = _flux2_kv_causal_attention(
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
)
elif kv_cache_mode == "cached" and kv_cache is not None:
hidden_states = _flux2_kv_causal_attention(
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
)
else:
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = Flux2AttnProcessor
_available_processors = [Flux2AttnProcessor]
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
def __init__(
self,
@@ -312,6 +621,90 @@ class Flux2ParallelSelfAttnProcessor:
return hidden_states
class Flux2KVParallelSelfAttnProcessor:
"""
Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When
`kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves
identically to `Flux2ParallelSelfAttnProcessor`.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
def __call__(
self,
attn: "Flux2ParallelSelfAttention",
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
kv_cache: Flux2KVLayerCache | None = None,
kv_cache_mode: str | None = None,
num_txt_tokens: int = 0,
num_ref_tokens: int = 0,
) -> torch.Tensor:
# Parallel in (QKV + MLP in) projection
hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
qkv, mlp_hidden_states = torch.split(
hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
)
query, key, value = qkv.chunk(3, dim=-1)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
query = attn.norm_q(query)
key = attn.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
# Extract ref K/V from the combined sequence
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
ref_start = num_txt_tokens
ref_end = num_txt_tokens + num_ref_tokens
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
# Dispatch attention
if kv_cache_mode == "extract" and num_ref_tokens > 0:
attn_output = _flux2_kv_causal_attention(
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
)
elif kv_cache_mode == "cached" and kv_cache is not None:
attn_output = _flux2_kv_causal_attention(
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
)
else:
attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
attn_output = attn_output.flatten(2, 3)
attn_output = attn_output.to(query.dtype)
# Handle the feedforward (FF) logic
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
# Concatenate and parallel output projection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
hidden_states = attn.to_out(hidden_states)
return hidden_states
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
"""
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
@@ -322,7 +715,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
"""
_default_processor_cls = Flux2ParallelSelfAttnProcessor
_available_processors = [Flux2ParallelSelfAttnProcessor]
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
# Does not support QKV fusion as the QKV projections are always fused
_supports_qkv_fusion = False
@@ -780,6 +1173,8 @@ class Flux2Transformer2DModel(
self.gradient_checkpointing = False
_skip_keys = ["kv_cache"]
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
@@ -791,19 +1186,21 @@ class Flux2Transformer2DModel(
guidance: torch.Tensor = None,
joint_attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor | Transformer2DModelOutput:
kv_cache: "Flux2KVCache | None" = None,
kv_cache_mode: str | None = None,
num_ref_tokens: int = 0,
ref_fixed_timestep: float = 0.0,
) -> torch.Tensor | Flux2Transformer2DModelOutput:
"""
The [`FluxTransformer2DModel`] forward method.
The [`Flux2Transformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
timestep ( `torch.LongTensor`):
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -811,13 +1208,23 @@ class Flux2Transformer2DModel(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
kv_cache (`Flux2KVCache`, *optional*):
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
returned. When "cached", the provided cache is used to inject ref K/V during attention.
kv_cache_mode (`str`, *optional*):
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
`None`, standard forward pass without KV caching.
num_ref_tokens (`int`, defaults to `0`):
Number of reference image tokens prepended to `hidden_states` (only used when
`kv_cache_mode="extract"`).
ref_fixed_timestep (`float`, defaults to `0.0`):
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
populated `Flux2KVCache`.
"""
# 0. Handle input arguments
num_txt_tokens = encoder_hidden_states.shape[1]
# 1. Calculate timestep embedding and modulation parameters
@@ -832,13 +1239,33 @@ class Flux2Transformer2DModel(
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)
# KV extract mode: create cache and blend modulations for ref tokens
if kv_cache_mode == "extract" and num_ref_tokens > 0:
num_img_tokens = hidden_states.shape[1] # includes ref tokens
kv_cache = Flux2KVCache(
num_double_layers=len(self.transformer_blocks),
num_single_layers=len(self.single_transformer_blocks),
)
kv_cache.num_ref_tokens = num_ref_tokens
# Ref tokens use a fixed timestep for modulation
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
ref_single_mod = self.single_stream_modulation(ref_temb)
# Blend double block img modulation: [ref_mod, img_mod]
double_stream_mod_img = _blend_double_block_mods(
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
)
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
# 3. Calculate RoPE embeddings from image and text tokens
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
# text prompts of differents lengths. Is this a use case we want to support?
if img_ids.ndim == 3:
img_ids = img_ids[0]
if txt_ids.ndim == 3:
@@ -851,8 +1278,29 @@ class Flux2Transformer2DModel(
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
)
# 4. Double Stream Transformer Blocks
# 4. Build joint_attention_kwargs with KV cache info
if kv_cache_mode == "extract":
kv_attn_kwargs = {
**(joint_attention_kwargs or {}),
"kv_cache": None,
"kv_cache_mode": "extract",
"num_ref_tokens": num_ref_tokens,
}
elif kv_cache_mode == "cached" and kv_cache is not None:
kv_attn_kwargs = {
**(joint_attention_kwargs or {}),
"kv_cache": None,
"kv_cache_mode": "cached",
"num_ref_tokens": kv_cache.num_ref_tokens,
}
else:
kv_attn_kwargs = joint_attention_kwargs
# 5. Double Stream Transformer Blocks
for index_block, block in enumerate(self.transformer_blocks):
if kv_cache_mode is not None and kv_cache is not None:
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
@@ -861,7 +1309,7 @@ class Flux2Transformer2DModel(
double_stream_mod_img,
double_stream_mod_txt,
concat_rotary_emb,
joint_attention_kwargs,
kv_attn_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
@@ -870,13 +1318,30 @@ class Flux2Transformer2DModel(
temb_mod_img=double_stream_mod_img,
temb_mod_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
joint_attention_kwargs=kv_attn_kwargs,
)
# Concatenate text and image streams for single-block inference
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 5. Single Stream Transformer Blocks
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
if kv_cache_mode == "extract" and num_ref_tokens > 0:
total_single_len = hidden_states.shape[1]
single_stream_mod = _blend_single_block_mods(
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
)
# Build single-block KV kwargs (single blocks need num_txt_tokens)
if kv_cache_mode is not None:
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
else:
kv_attn_kwargs_single = kv_attn_kwargs
# 6. Single Stream Transformer Blocks
for index_block, block in enumerate(self.single_transformer_blocks):
if kv_cache_mode is not None and kv_cache is not None:
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
@@ -884,7 +1349,7 @@ class Flux2Transformer2DModel(
None,
single_stream_mod,
concat_rotary_emb,
joint_attention_kwargs,
kv_attn_kwargs_single,
)
else:
hidden_states = block(
@@ -892,16 +1357,25 @@ class Flux2Transformer2DModel(
encoder_hidden_states=None,
temb_mod=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
joint_attention_kwargs=kv_attn_kwargs_single,
)
# Remove text tokens from concatenated stream
hidden_states = hidden_states[:, num_txt_tokens:, ...]
# 6. Output layers
# Remove text tokens (and ref tokens in extract mode) from concatenated stream
if kv_cache_mode == "extract" and num_ref_tokens > 0:
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
else:
hidden_states = hidden_states[:, num_txt_tokens:, ...]
# 7. Output layers
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if kv_cache_mode == "extract":
if not return_dict:
return (output, kv_cache)
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
return Flux2Transformer2DModelOutput(sample=output)

View File

@@ -129,7 +129,7 @@ else:
]
_import_structure["bria"] = ["BriaPipeline"]
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hidream_image import HiDreamImagePipeline

View File

@@ -95,6 +95,7 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
@@ -185,6 +186,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
("z-image-omni", ZImageOmniPipeline),
("ovis", OvisImagePipeline),
("prx", PRXPipeline),
]
)

View File

@@ -82,13 +82,16 @@ EXAMPLE_DOC_STRING = """
```python
>>> import cv2
>>> import numpy as np
>>> from PIL import Image
>>> import torch
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> controlnet = AutoModel.from_pretrained(
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
... )
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )

View File

@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_flux2 import Flux2Pipeline
from .pipeline_flux2_klein import Flux2KleinPipeline
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
else:
import sys

View File

@@ -0,0 +1,886 @@
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable
import numpy as np
import PIL
import torch
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
from ...loaders import Flux2LoraLoaderMixin
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .image_processor import Flux2ImageProcessor
from .pipeline_output import Flux2PipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import Flux2KleinKVPipeline
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> ref_image = Image.open("reference.png")
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
>>> image.save("flux2_kv_output.png")
```
"""
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`list[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`list[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
r"""
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
inference when using reference images.
Reference:
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
Args:
transformer ([`Flux2Transformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLFlux2`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`Qwen3ForCausalLM`]):
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
tokenizer (`Qwen2TokenizerFast`):
Tokenizer of class
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLFlux2,
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
transformer: Flux2Transformer2DModel,
is_distilled: bool = True,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 512
self.default_sample_size = 128
# Set KV-cache-aware attention processors
self._set_kv_attn_processors()
@staticmethod
def _get_qwen3_prompt_embeds(
text_encoder: Qwen3ForCausalLM,
tokenizer: Qwen2TokenizerFast,
prompt: str | list[str],
dtype: torch.dtype | None = None,
device: torch.device | None = None,
max_sequence_length: int = 512,
hidden_states_layers: list[int] = (9, 18, 27),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
all_input_ids = []
all_attention_masks = []
for single_prompt in prompt:
messages = [{"role": "user", "content": single_prompt}]
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
inputs = tokenizer(
text,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
all_input_ids.append(inputs["input_ids"])
all_attention_masks.append(inputs["attention_mask"])
input_ids = torch.cat(all_input_ids, dim=0).to(device)
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
# Forward pass through the model
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
# Only use outputs from intermediate layers and stack them
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
def _prepare_text_ids(
x: torch.Tensor, # (B, L, D) or (L, D)
t_coord: torch.Tensor | None = None,
):
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, l)
out_ids.append(coords)
return torch.stack(out_ids)
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
def _prepare_latent_ids(
latents: torch.Tensor, # (B, C, H, W)
):
r"""
Generates 4D position coordinates (T, H, W, L) for latent tensors.
Args:
latents (torch.Tensor):
Latent tensor of shape (B, C, H, W)
Returns:
torch.Tensor:
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
H=[0..H-1], W=[0..W-1], L=0
"""
batch_size, _, height, width = latents.shape
t = torch.arange(1) # [0] - time dimension
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1) # [0] - layer dimension
# Create position IDs: (H*W, 4)
latent_ids = torch.cartesian_prod(t, h, w, l)
# Expand to batch: (B, H*W, 4)
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
return latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
def _prepare_image_ids(
image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
scale: int = 10,
):
r"""
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
This function creates a unique coordinate for every pixel/patch across all input latent with different
dimensions.
Args:
image_latents (list[torch.Tensor]):
A list of image latent feature tensors, typically of shape (C, H, W).
scale (int, optional):
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
latent is: 'scale + scale * i'. Defaults to 10.
Returns:
torch.Tensor:
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
input latents.
Coordinate Components (Dimension 4):
- T (Time): The unique index indicating which latent image the coordinate belongs to.
- H (Height): The row index within that latent image.
- W (Width): The column index within that latent image.
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
"""
if not isinstance(image_latents, list):
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
# create time offset for each reference image
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
def _patchify_latents(latents):
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
def _unpatchify_latents(latents):
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
latents = latents.permute(0, 1, 4, 2, 5, 3)
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
def _pack_latents(latents):
"""
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
"""
using position ids to scatter tokens into place
"""
x_list = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
out = out.view(h, w, ch).permute(2, 0, 1)
x_list.append(out)
return torch.stack(x_list, dim=0)
def _set_kv_attn_processors(self):
"""Replace default attention processors with KV-cache-aware variants."""
for block in self.transformer.transformer_blocks:
block.attn.set_processor(Flux2KVAttnProcessor())
for block in self.transformer.single_transformer_blocks:
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
def encode_prompt(
self,
prompt: str | list[str],
device: torch.device | None = None,
num_images_per_prompt: int = 1,
prompt_embeds: torch.Tensor | None = None,
max_sequence_length: int = 512,
text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
device = device or self._execution_device
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_embeds = self._get_qwen3_prompt_embeds(
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
prompt=prompt,
device=device,
max_sequence_length=max_sequence_length,
hidden_states_layers=text_encoder_out_layers,
)
batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
text_ids = self._prepare_text_ids(prompt_embeds)
text_ids = text_ids.to(device)
return prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if image.ndim != 4:
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_latents_channels,
height,
width,
dtype,
device,
generator: torch.Generator,
latents: torch.Tensor | None = None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
latent_ids = self._prepare_latent_ids(latents)
latent_ids = latent_ids.to(device)
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
return latents, latent_ids
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
def prepare_image_latents(
self,
images: list[torch.Tensor],
batch_size,
generator: torch.Generator,
device,
dtype,
):
image_latents = []
for image in images:
image = image.to(device=device, dtype=dtype)
imagge_latent = self._encode_vae_image(image=image, generator=generator)
image_latents.append(imagge_latent) # (1, 128, 32, 32)
image_latent_ids = self._prepare_image_ids(image_latents)
# Pack each latent and concatenate
packed_latents = []
for latent in image_latents:
# latent: (1, 128, 32, 32)
packed = self._pack_latents(latent) # (1, 1024, 128)
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
packed_latents.append(packed)
# Concatenate all reference tokens along sequence dimension
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
image_latents = image_latents.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.to(device)
return image_latents, image_latent_ids
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if (
height is not None
and height % (self.vae_scale_factor * 2) != 0
or width is not None
and width % (self.vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
prompt: str | list[str] = None,
height: int | None = None,
width: int | None = None,
num_inference_steps: int = 4,
sigmas: list[float] | None = None,
num_images_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
prompt_embeds: torch.Tensor | None = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
max_sequence_length: int = 512,
text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
recomputing.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation.
height (`int`, *optional*):
The height in pixels of the generated image.
width (`int`, *optional*):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 4):
The number of denoising steps.
sigmas (`List[float]`, *optional*):
Custom sigmas for the denoising schedule.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
Generator(s) for deterministic generation.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
Output format: `"pil"` or `"np"`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a `Flux2PipelineOutput` or a plain tuple.
attention_kwargs (`dict`, *optional*):
Extra kwargs passed to attention processors.
callback_on_step_end (`Callable`, *optional*):
Callback function called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List`, *optional*):
Tensor inputs for the callback function.
max_sequence_length (`int`, defaults to 512):
Maximum sequence length for the prompt.
text_encoder_out_layers (`tuple[int]`):
Layer indices for text encoder hidden state extraction.
Examples:
Returns:
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
"""
# 1. Check inputs
self.check_inputs(
prompt=prompt,
height=height,
width=width,
prompt_embeds=prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. prepare text embeddings
prompt_embeds, text_ids = self.encode_prompt(
prompt=prompt,
prompt_embeds=prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
text_encoder_out_layers=text_encoder_out_layers,
)
# 4. process images
if image is not None and not isinstance(image, list):
image = [image]
condition_images = None
if image is not None:
for img in image:
self.image_processor.check_image_input(img)
condition_images = []
for img in image:
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
multiple_of = self.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
condition_images.append(img)
height = height or image_height
width = width or image_width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 5. prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_ids = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_latents_channels=num_channels_latents,
height=height,
width=width,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
)
image_latents = None
image_latent_ids = None
if condition_images is not None:
image_latents, image_latent_ids = self.prepare_image_latents(
images=condition_images,
batch_size=batch_size * num_images_per_prompt,
generator=generator,
device=device,
dtype=self.vae.dtype,
)
# 6. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
sigmas = None
image_seq_len = latents.shape[1]
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 7. Denoising loop with KV caching
# Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
# Steps 1+: forward_kv_cached (reuse cached ref K/V)
# No ref images: standard forward
self.scheduler.set_begin_index(0)
kv_cache = None
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if i == 0 and image_latents is not None:
# Step 0: include ref tokens, extract KV cache
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
noise_pred, kv_cache = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
kv_cache_mode="extract",
num_ref_tokens=image_latents.shape[1],
)
elif kv_cache is not None:
# Steps 1+: use cached ref KV, no ref tokens in input
noise_pred = self.transformer(
hidden_states=latents.to(self.transformer.dtype),
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
kv_cache=kv_cache,
kv_cache_mode="cached",
)[0]
else:
# No reference images: standard forward
noise_pred = self.transformer(
hidden_states=latents.to(self.transformer.dtype),
timestep=timestep / 1000,
guidance=None,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_ids,
joint_attention_kwargs=self.attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# Clean up KV cache
if kv_cache is not None:
kv_cache.clear()
self._current_timestep = None
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
if output_type == "latent":
image = latents
else:
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return Flux2PipelineOutput(images=image)

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
from packaging import version
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
modules_to_not_convert: list[str] | None = None,
**kwargs,
):
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
deprecate("QuantoConfig", "1.0.0", deprecation_message)
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
deprecate,
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -1202,6 +1202,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinKVPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -0,0 +1,174 @@
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
from diffusers import (
AutoencoderKLFlux2,
FlowMatchEulerDiscreteScheduler,
Flux2KleinKVPipeline,
Flux2Transformer2DModel,
)
from ...testing_utils import torch_device
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Flux2KleinKVPipeline
params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
batch_params = frozenset(["prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
supports_dduf = False
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = Flux2Transformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=16,
timestep_guidance_channels=256,
axes_dims_rope=[4, 4, 4, 4],
guidance_embeds=False,
)
# Create minimal Qwen3 config
config = Qwen3Config(
intermediate_size=16,
hidden_size=16,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
vocab_size=151936,
max_position_embeddings=512,
)
torch.manual_seed(0)
text_encoder = Qwen3ForCausalLM(config)
# Use a simple tokenizer for testing
tokenizer = Qwen2TokenizerFast.from_pretrained(
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
)
torch.manual_seed(0)
vae = AutoencoderKLFlux2(
sample_size=32,
in_channels=3,
out_channels=3,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "a dog is dancing",
"image": Image.new("RGB", (64, 64)),
"generator": generator,
"num_inference_steps": 2,
"height": 8,
"width": 8,
"max_sequence_length": 64,
"output_type": "np",
"text_encoder_out_layers": (1,),
}
return inputs
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
self.assertTrue(
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
self.assertTrue(
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
("Fusion of QKV projections shouldn't affect the outputs."),
)
self.assertTrue(
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
)
self.assertTrue(
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
("Original outputs should match when fused QKV projections are disabled."),
)
def test_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
self.assertEqual(
(output_height, output_width),
(expected_height, expected_width),
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
)
def test_without_image(self):
device = "cpu"
pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
inputs = self.get_dummy_inputs(device)
del inputs["image"]
image = pipe(**inputs).images
self.assertEqual(image.shape, (1, 8, 8, 3))
@unittest.skip("Needs to be revisited")
def test_encode_prompt_works_in_isolation(self):
pass