mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-13 20:17:53 +08:00
Compare commits
6 Commits
flash-3-hu
...
tests-load
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faca9b90a7 | ||
|
|
a1f63a398c | ||
|
|
bf846f722c | ||
|
|
78a86e85cf | ||
|
|
7673ab1757 | ||
|
|
b7648557d4 |
@@ -17,7 +17,3 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
|
||||
## Flux2Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput
|
||||
|
||||
@@ -41,11 +41,5 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
|
||||
## Flux2KleinPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Flux2KleinKVPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinKVPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -510,7 +510,6 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2KleinKVPipeline",
|
||||
"Flux2KleinPipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
@@ -1267,7 +1266,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2KleinPipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
|
||||
@@ -2538,12 +2538,8 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha_tensor = state_dict.pop(alpha_key, None)
|
||||
if alpha_tensor is None:
|
||||
return 1.0, 1.0
|
||||
scale = (
|
||||
alpha_tensor.item() / rank
|
||||
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
|
||||
@@ -2559,9 +2559,7 @@ def _flash_attention_3_hub(
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError(
|
||||
"`attn_mask` is not supported for flash-attn 3. Please use the `_flash_3_varlen_hub` backend instead."
|
||||
)
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
@@ -2643,8 +2641,6 @@ def _flash_attention_3_varlen_hub(
|
||||
_, seq_len_kv, _, _ = key.shape
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask > -1
|
||||
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
|
||||
|
||||
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
|
||||
@@ -2664,7 +2660,7 @@ def _flash_attention_3_varlen_hub(
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
|
||||
result = func(
|
||||
out, lse, *_ = func(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
@@ -2675,11 +2671,6 @@ def _flash_attention_3_varlen_hub(
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(result, tuple):
|
||||
out, lse, *_ = result
|
||||
else:
|
||||
out = result
|
||||
lse = None
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -22,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -33,6 +32,7 @@ from ..embeddings import (
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
@@ -40,216 +40,6 @@ from ..normalization import AdaLayerNormContinuous
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Flux2Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Flux2Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
kv_cache: "Flux2KVCache | None" = None
|
||||
|
||||
|
||||
class Flux2KVLayerCache:
|
||||
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
|
||||
|
||||
Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor
|
||||
format: (batch_size, num_ref_tokens, num_heads, head_dim).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.k_ref: torch.Tensor | None = None
|
||||
self.v_ref: torch.Tensor | None = None
|
||||
|
||||
def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
|
||||
"""Store reference token K/V."""
|
||||
self.k_ref = k_ref
|
||||
self.v_ref = v_ref
|
||||
|
||||
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Retrieve cached reference token K/V."""
|
||||
if self.k_ref is None:
|
||||
raise RuntimeError("KV cache has not been populated yet.")
|
||||
return self.k_ref, self.v_ref
|
||||
|
||||
def clear(self):
|
||||
self.k_ref = None
|
||||
self.v_ref = None
|
||||
|
||||
|
||||
class Flux2KVCache:
|
||||
"""Container for all layers' reference-token KV caches.
|
||||
|
||||
Holds separate cache lists for double-stream and single-stream transformer blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, num_double_layers: int, num_single_layers: int):
|
||||
self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
|
||||
self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
|
||||
self.num_ref_tokens: int = 0
|
||||
|
||||
def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.double_block_caches[layer_idx]
|
||||
|
||||
def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.single_block_caches[layer_idx]
|
||||
|
||||
def clear(self):
|
||||
for cache in self.double_block_caches:
|
||||
cache.clear()
|
||||
for cache in self.single_block_caches:
|
||||
cache.clear()
|
||||
self.num_ref_tokens = 0
|
||||
|
||||
|
||||
def _flux2_kv_causal_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_txt_tokens: int,
|
||||
num_ref_tokens: int,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
backend=None,
|
||||
) -> torch.Tensor:
|
||||
"""Causal attention for KV caching where reference tokens only self-attend.
|
||||
|
||||
All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
|
||||
|
||||
Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens
|
||||
only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected
|
||||
between txt and img.
|
||||
"""
|
||||
# No ref tokens and no cache — standard full attention
|
||||
if num_ref_tokens == 0 and kv_cache is None:
|
||||
return dispatch_attention_fn(query, key, value, backend=backend)
|
||||
|
||||
if kv_cache is not None:
|
||||
# Cached mode: inject ref K/V between txt and img
|
||||
k_ref, v_ref = kv_cache.get()
|
||||
|
||||
k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
|
||||
v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
|
||||
|
||||
return dispatch_attention_fn(query, k_all, v_all, backend=backend)
|
||||
|
||||
# Extract mode: ref tokens self-attend, txt+img attend to all
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
|
||||
q_txt = query[:, :ref_start]
|
||||
q_ref = query[:, ref_start:ref_end]
|
||||
q_img = query[:, ref_end:]
|
||||
|
||||
k_txt = key[:, :ref_start]
|
||||
k_ref = key[:, ref_start:ref_end]
|
||||
k_img = key[:, ref_end:]
|
||||
|
||||
v_txt = value[:, :ref_start]
|
||||
v_ref = value[:, ref_start:ref_end]
|
||||
v_img = value[:, ref_end:]
|
||||
|
||||
# txt+img attend to all tokens
|
||||
q_txt_img = torch.cat([q_txt, q_img], dim=1)
|
||||
k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
|
||||
v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
|
||||
attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
|
||||
attn_txt = attn_txt_img[:, :ref_start]
|
||||
attn_img = attn_txt_img[:, ref_start:]
|
||||
|
||||
# ref tokens self-attend only
|
||||
attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
|
||||
|
||||
return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
|
||||
|
||||
|
||||
def _blend_mod_params(
|
||||
img_params: tuple[torch.Tensor, ...],
|
||||
ref_params: tuple[torch.Tensor, ...],
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return tuple(blended)
|
||||
|
||||
|
||||
def _blend_double_block_mods(
|
||||
img_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend double-block image-stream modulations for a [ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible
|
||||
with `Flux2Modulation.split(mod, 2)`.
|
||||
"""
|
||||
if img_mod.ndim == 2:
|
||||
img_mod = img_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_chunks = torch.chunk(img_mod, 6, dim=-1)
|
||||
ref_chunks = torch.chunk(ref_mod, 6, dim=-1)
|
||||
img_mods = (img_chunks[0:3], img_chunks[3:6])
|
||||
ref_mods = (ref_chunks[0:3], ref_chunks[3:6])
|
||||
|
||||
all_params = []
|
||||
for img_set, ref_set in zip(img_mods, ref_mods):
|
||||
blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
|
||||
all_params.extend(blended)
|
||||
return torch.cat(all_params, dim=-1)
|
||||
|
||||
|
||||
def _blend_single_block_mods(
|
||||
single_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_txt: int,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
|
||||
"""
|
||||
if single_mod.ndim == 2:
|
||||
single_mod = single_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_params = torch.chunk(single_mod, 3, dim=-1)
|
||||
ref_params = torch.chunk(ref_mod, 3, dim=-1)
|
||||
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
im_expanded = im.expand(B, seq_len, -1)
|
||||
rm_expanded = rm.expand(B, num_ref, -1)
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return torch.cat(blended, dim=-1)
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@@ -391,108 +181,9 @@ class Flux2AttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is
|
||||
used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are
|
||||
injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -621,90 +312,6 @@ class Flux2ParallelSelfAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVParallelSelfAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When
|
||||
`kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves
|
||||
identically to `Flux2ParallelSelfAttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2ParallelSelfAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_txt_tokens: int = 0,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Parallel in (QKV + MLP in) projection
|
||||
hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
|
||||
qkv, mlp_hidden_states = torch.split(
|
||||
hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
||||
)
|
||||
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
attn_output = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
attn_output = attn_output.flatten(2, 3)
|
||||
attn_output = attn_output.to(query.dtype)
|
||||
|
||||
# Handle the feedforward (FF) logic
|
||||
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
||||
|
||||
# Concatenate and parallel output projection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
@@ -715,7 +322,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
@@ -1173,8 +780,6 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
_skip_keys = ["kv_cache"]
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
@@ -1186,21 +791,19 @@ class Flux2Transformer2DModel(
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
kv_cache: "Flux2KVCache | None" = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
ref_fixed_timestep: float = 0.0,
|
||||
) -> torch.Tensor | Flux2Transformer2DModelOutput:
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`Flux2Transformer2DModel`] forward method.
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep (`torch.LongTensor`):
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -1208,23 +811,13 @@ class Flux2Transformer2DModel(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
|
||||
returned. When "cached", the provided cache is used to inject ref K/V during attention.
|
||||
kv_cache_mode (`str`, *optional*):
|
||||
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
|
||||
`None`, standard forward pass without KV caching.
|
||||
num_ref_tokens (`int`, defaults to `0`):
|
||||
Number of reference image tokens prepended to `hidden_states` (only used when
|
||||
`kv_cache_mode="extract"`).
|
||||
ref_fixed_timestep (`float`, defaults to `0.0`):
|
||||
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
|
||||
populated `Flux2KVCache`.
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
@@ -1239,33 +832,13 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
|
||||
# KV extract mode: create cache and blend modulations for ref tokens
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
num_img_tokens = hidden_states.shape[1] # includes ref tokens
|
||||
|
||||
kv_cache = Flux2KVCache(
|
||||
num_double_layers=len(self.transformer_blocks),
|
||||
num_single_layers=len(self.single_transformer_blocks),
|
||||
)
|
||||
kv_cache.num_ref_tokens = num_ref_tokens
|
||||
|
||||
# Ref tokens use a fixed timestep for modulation
|
||||
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
|
||||
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
|
||||
|
||||
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
|
||||
ref_single_mod = self.single_stream_modulation(ref_temb)
|
||||
|
||||
# Blend double block img modulation: [ref_mod, img_mod]
|
||||
double_stream_mod_img = _blend_double_block_mods(
|
||||
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
|
||||
)
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of differents lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
@@ -1278,29 +851,8 @@ class Flux2Transformer2DModel(
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Build joint_attention_kwargs with KV cache info
|
||||
if kv_cache_mode == "extract":
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "extract",
|
||||
"num_ref_tokens": num_ref_tokens,
|
||||
}
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "cached",
|
||||
"num_ref_tokens": kv_cache.num_ref_tokens,
|
||||
}
|
||||
else:
|
||||
kv_attn_kwargs = joint_attention_kwargs
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
# 4. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -1309,7 +861,7 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
kv_attn_kwargs,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
@@ -1318,30 +870,13 @@ class Flux2Transformer2DModel(
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=kv_attn_kwargs,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
total_single_len = hidden_states.shape[1]
|
||||
single_stream_mod = _blend_single_block_mods(
|
||||
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
|
||||
)
|
||||
|
||||
# Build single-block KV kwargs (single blocks need num_txt_tokens)
|
||||
if kv_cache_mode is not None:
|
||||
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
|
||||
else:
|
||||
kv_attn_kwargs_single = kv_attn_kwargs
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
# 5. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -1349,7 +884,7 @@ class Flux2Transformer2DModel(
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
kv_attn_kwargs_single,
|
||||
joint_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -1357,25 +892,16 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states=None,
|
||||
temb_mod=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=kv_attn_kwargs_single,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# Remove text tokens (and ref tokens in extract mode) from concatenated stream
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
|
||||
else:
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 7. Output layers
|
||||
# 6. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if kv_cache_mode == "extract":
|
||||
if not return_dict:
|
||||
return (output, kv_cache)
|
||||
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Flux2Transformer2DModelOutput(sample=output)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -129,7 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
|
||||
@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
||||
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -34,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
from .pipeline_flux2_klein import Flux2KleinPipeline
|
||||
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,886 +0,0 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import Flux2KleinKVPipeline
|
||||
|
||||
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> ref_image = Image.open("reference.png")
|
||||
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
|
||||
>>> image.save("flux2_kv_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
|
||||
|
||||
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
|
||||
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
|
||||
inference when using reference images.
|
||||
|
||||
Reference:
|
||||
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3ForCausalLM`]):
|
||||
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
||||
tokenizer (`Qwen2TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
transformer: Flux2Transformer2DModel,
|
||||
is_distilled: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Set KV-cache-aware attention processors
|
||||
self._set_kv_attn_processors()
|
||||
|
||||
@staticmethod
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: str | list[str],
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: list[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: torch.Tensor | None = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
||||
H=[0..H-1], W=[0..W-1], L=0
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
||||
def _prepare_image_ids(
|
||||
image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latent with different
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (list[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
||||
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def _set_kv_attn_processors(self):
|
||||
"""Replace default attention processors with KV-cache-aware variants."""
|
||||
for block in self.transformer.transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVAttnProcessor())
|
||||
for block in self.transformer.single_transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
image_latents = self._patchify_latents(image_latents)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
||||
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: torch.Tensor | None = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
batch_size,
|
||||
generator: torch.Generator,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
imagge_latent = self._encode_vae_image(image=image, generator=generator)
|
||||
image_latents.append(imagge_latent) # (1, 128, 32, 32)
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
# Pack each latent and concatenate
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
# latent: (1, 128, 32, 32)
|
||||
packed = self._pack_latents(latent) # (1, 1024, 128)
|
||||
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
||||
packed_latents.append(packed)
|
||||
|
||||
# Concatenate all reference tokens along sequence dimension
|
||||
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
||||
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 4,
|
||||
sigmas: list[float] | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
|
||||
Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
|
||||
forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
|
||||
recomputing.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 4):
|
||||
The number of denoising steps.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas for the denoising schedule.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
Generator(s) for deterministic generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: `"pil"` or `"np"`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a `Flux2PipelineOutput` or a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Extra kwargs passed to attention processors.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Callback function called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
Tensor inputs for the callback function.
|
||||
max_sequence_length (`int`, defaults to 512):
|
||||
Maximum sequence length for the prompt.
|
||||
text_encoder_out_layers (`tuple[int]`):
|
||||
Layer indices for text encoder hidden state extraction.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
condition_images = None
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
|
||||
condition_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 5. prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_ids = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_latents_channels=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
image_latents = None
|
||||
image_latent_ids = None
|
||||
if condition_images is not None:
|
||||
image_latents, image_latent_ids = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop with KV caching
|
||||
# Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
|
||||
# Steps 1+: forward_kv_cached (reuse cached ref K/V)
|
||||
# No ref images: standard forward
|
||||
self.scheduler.set_begin_index(0)
|
||||
kv_cache = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if i == 0 and image_latents is not None:
|
||||
# Step 0: include ref tokens, extract KV cache
|
||||
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
|
||||
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
|
||||
|
||||
noise_pred, kv_cache = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=image_latents.shape[1],
|
||||
)
|
||||
|
||||
elif kv_cache is not None:
|
||||
# Steps 1+: use cached ref KV, no ref tokens in input
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache_mode="cached",
|
||||
)[0]
|
||||
|
||||
else:
|
||||
# No reference images: standard forward
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# Clean up KV cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.clear()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return Flux2PipelineOutput(images=image)
|
||||
@@ -1202,21 +1202,6 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -32,6 +33,33 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLFlux2,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2Transformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Flux2KleinKVPipeline
|
||||
params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = Flux2Transformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=16,
|
||||
timestep_guidance_channels=256,
|
||||
axes_dims_rope=[4, 4, 4, 4],
|
||||
guidance_embeds=False,
|
||||
)
|
||||
|
||||
# Create minimal Qwen3 config
|
||||
config = Qwen3Config(
|
||||
intermediate_size=16,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = Qwen3ForCausalLM(config)
|
||||
|
||||
# Use a simple tokenizer for testing
|
||||
tokenizer = Qwen2TokenizerFast.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLFlux2(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a dog is dancing",
|
||||
"image": Image.new("RGB", (64, 64)),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 64,
|
||||
"output_type": "np",
|
||||
"text_encoder_out_layers": (1,),
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
|
||||
("Fusion of QKV projections shouldn't affect the outputs."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
|
||||
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
|
||||
("Original outputs should match when fused QKV projections are disabled."),
|
||||
)
|
||||
|
||||
def test_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
self.assertEqual(
|
||||
(output_height, output_width),
|
||||
(expected_height, expected_width),
|
||||
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
|
||||
)
|
||||
|
||||
def test_without_image(self):
|
||||
device = "cpu"
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["image"]
|
||||
image = pipe(**inputs).images
|
||||
self.assertEqual(image.shape, (1, 8, 8, 3))
|
||||
|
||||
@unittest.skip("Needs to be revisited")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user