Compare commits

...

11 Commits

Author SHA1 Message Date
DN6
da34261cc2 update 2025-10-08 20:24:54 +05:30
DN6
14cfbab078 update 2025-10-08 20:24:17 +05:30
DN6
a69b1e06fc Merge branch 'main' into attn-refactor-blocks 2025-10-08 14:40:26 +05:30
DN6
77c4e0932c Merge branch 'main' into attn-refactor-blocks 2025-10-03 11:39:22 +05:30
DN6
fc322ed052 update 2025-10-03 11:38:16 +05:30
DN6
fed2c46482 update 2025-10-03 11:36:21 +05:30
DN6
66320f031a update 2025-10-03 11:35:39 +05:30
DN6
86a1290e51 update 2025-08-26 10:44:43 +05:30
DN6
57f374b87b update 2025-08-25 14:05:29 +05:30
DN6
3b2e85d853 update 2025-08-25 13:23:40 +05:30
DN6
12b4edc2fe update 2025-08-25 11:15:24 +05:30
47 changed files with 2378 additions and 964 deletions

View File

@@ -151,8 +151,8 @@ def _register_attention_processors_metadata():
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_2d import BasicTransformerBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock

View File

@@ -21,10 +21,8 @@ import torch.nn.functional as F
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .attention_processor import Attention, AttentionProcessor
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
if is_xformers_available():
@@ -505,19 +503,16 @@ class AttentionModuleMixin:
return encoder_hidden_states
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
def _chunked_feed_forward(*args, **kwargs):
deprecate(
"_chunked_feed_forward",
"1.0.0",
"Importing `_chunked_feed_forward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import _chunked_feed_forward` instead.",
standard_warn=False,
)
return ff_output
from .transformers.modeling_common import _chunked_feed_forward
return _chunked_feed_forward(*args, **kwargs)
@maybe_allow_in_graph
@@ -577,161 +572,16 @@ class JointTransformerBlock(nn.Module):
processing of `context` conditions.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
):
super().__init__()
self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim)
else:
self.norm1 = AdaLayerNormZero(dim)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim)
else:
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
def __new__(cls, *args, **kwargs):
deprecate(
"JointTransformerBlock",
"1.0.0",
"Importing `JointTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_sd3 import SD3TransformerBlock` instead.",
standard_warn=False,
)
from .transformers.transformer_sd3 import SD3TransformerBlock
if use_dual_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
else:
self.attn2 = None
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
if not context_pre_only:
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
else:
self.norm2_context = None
self.ff_context = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = _chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
)
else:
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return encoder_hidden_states, hidden_states
return SD3TransformerBlock(*args, **kwargs)
@maybe_allow_in_graph
@@ -770,300 +620,16 @@ class BasicTransformerBlock(nn.Module):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
def __new__(cls, *args, **kwargs):
deprecate(
"BasicTransformerBlock",
"1.0.0",
"Importing `BasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_2d import BasicTransformerBlock` instead.",
standard_warn=False,
)
from .transformers.transformer_2d import BasicTransformerBlock
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single": # For Latte
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
return BasicTransformerBlock(*args, **kwargs)
class LuminaFeedForward(nn.Module):
@@ -1081,38 +647,16 @@ class LuminaFeedForward(nn.Module):
dimension. Defaults to None.
"""
def __init__(
self,
dim: int,
inner_dim: int,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
def __new__(cls, *args, **kwargs):
deprecate(
"LuminaFeedForward",
"1.0.0",
"Importing `LuminaFeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_lumina2 import LuminaFeedForward` instead.",
standard_warn=False,
)
from .transformers.transformer_lumina2 import LuminaFeedForward
self.linear_1 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.linear_2 = nn.Linear(
inner_dim,
dim,
bias=False,
)
self.linear_3 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.silu = FP32SiLU()
def forward(self, x):
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
return LuminaFeedForward(*args, **kwargs)
@maybe_allow_in_graph
@@ -1128,193 +672,29 @@ class TemporalBasicTransformerBlock(nn.Module):
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
def __new__(cls, *args, **kwargs):
deprecate(
"TemporalBasicTransformerBlock",
"1.0.0",
"Importing `TemporalBasicTransformerBlock` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.transformer_temporal import TemporalBasicTransformerBlock` instead.",
standard_warn=False,
)
from .transformers.transformer_temporal import TemporalBasicTransformerBlock
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
return TemporalBasicTransformerBlock(*args, **kwargs)
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
def __new__(cls, *args, **kwargs):
deprecate(
"SkipFFTransformerBlock",
"1.0.0",
"Importing `SkipFFTransformerBlock` from `diffusers.models.attention` is deprecated and this will be removed in a future version. Please use `from diffusers.models.unets.uvit_2d import SkipFFTransformerBlock` instead.",
standard_warn=False,
)
from .unets.uvit_2d import SkipFFTransformerBlock
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
return SkipFFTransformerBlock(*args, **kwargs)
@maybe_allow_in_graph
@@ -1679,50 +1059,13 @@ class FeedForward(nn.Module):
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
def __new__(cls, *args, **kwargs):
deprecate(
"FeedForward",
"1.0.0",
"Importing `FeedForward` from `diffusers.models.attention` is deprecated. Please use `from diffusers.models.transformers.modeling_common import FeedForward` instead.",
standard_warn=False,
)
from .transformers.modeling_common import FeedForward
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
return FeedForward(*args, **kwargs)

View File

@@ -24,8 +24,8 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, u
from ..attention_processor import AttentionProcessor
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.modeling_common import Transformer2DModelOutput
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock

View File

@@ -24,8 +24,8 @@ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, u
from ..attention_processor import AttentionProcessor
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.modeling_common import Transformer2DModelOutput
from ..transformers.transformer_qwenimage import (
QwenEmbedRope,
QwenImageTransformerBlock,

View File

@@ -23,9 +23,9 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from ..transformers.modeling_common import Transformer2DModelOutput
from ..transformers.sana_transformer import SanaTransformerBlock
from .controlnet import zero_module

View File

@@ -22,12 +22,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import JointTransformerBlock
from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
from ..transformers.modeling_common import Transformer2DModelOutput
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock, SD3TransformerBlock
from .controlnet import BaseOutput, zero_module
@@ -132,7 +131,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
# It needs to crafted when we get the actual checkpoints.
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
SD3TransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,

View File

@@ -1530,7 +1530,7 @@ class ImageProjection(nn.Module):
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
super().__init__()
from .attention import FeedForward
from .transformers.modeling_common import FeedForward
self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
self.norm = nn.LayerNorm(cross_attention_dim)
@@ -1542,7 +1542,7 @@ class IPAdapterFullImageProjection(nn.Module):
class IPAdapterFaceIDImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
super().__init__()
from .attention import FeedForward
from .transformers.modeling_common import FeedForward
self.num_tokens = num_tokens
self.cross_attention_dim = cross_attention_dim
@@ -2219,7 +2219,7 @@ class IPAdapterPlusImageProjectionBlock(nn.Module):
ffn_ratio: float = 4,
) -> None:
super().__init__()
from .attention import FeedForward
from .transformers.modeling_common import FeedForward
self.ln0 = nn.LayerNorm(embed_dims)
self.ln1 = nn.LayerNorm(embed_dims)
@@ -2334,7 +2334,7 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
ffproj_ratio: int = 2,
) -> None:
super().__init__()
from .attention import FeedForward
from .transformers.modeling_common import FeedForward
self.num_tokens = num_tokens
self.embed_dim = embed_dims
@@ -2404,7 +2404,7 @@ class IPAdapterTimeImageProjectionBlock(nn.Module):
ffn_ratio: int = 4,
) -> None:
super().__init__()
from .attention import FeedForward
from .transformers.modeling_common import FeedForward
self.ln0 = nn.LayerNorm(hidden_dim)
self.ln1 = nn.LayerNorm(hidden_dim)

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
@dataclass
@@ -17,8 +17,7 @@ class AutoencoderKLOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution" # noqa: F821
@dataclass
class Transformer2DModelOutput(BaseOutput):
class Transformer2DModelOutput:
"""
The output of [`Transformer2DModel`].
@@ -28,4 +27,13 @@ class Transformer2DModelOutput(BaseOutput):
distributions for the unnoised latent pixels.
"""
sample: "torch.Tensor" # noqa: F821
def __new__(cls, *args, **kwargs):
deprecate(
"Transformer2DModelOutput",
"1.0.0",
"Importing `Transformer2DModelOutput` from `diffusers.models.modeling_outputs` is deprecated. Please use `from diffusers.models.transformers.modeling_common import Transformer2DModelOutput` instead.",
standard_warn=False,
)
from .transformers.modeling_common import Transformer2DModelOutput
return Transformer2DModelOutput(*args, **kwargs)

View File

@@ -30,9 +30,9 @@ from ..attention_processor import (
FusedAuraFlowAttnProcessor2_0,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormZero, FP32LayerNorm
from .modeling_common import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -194,7 +194,8 @@ class AuraFlowSingleTransformerBlock(nn.Module):
@maybe_allow_in_graph
class AuraFlowJointTransformerBlock(nn.Module):
# Copied from diffusers.models.transformers.transformer_sd3.SD3TransformerBlock with SD3->AuraFlow
class AuraFlowTransformerBlock(nn.Module):
r"""
Transformer block for Aura Flow. Similar to SD3 MMDiT. Differences (non-exhaustive):
@@ -337,7 +338,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
self.joint_transformer_blocks = nn.ModuleList(
[
AuraFlowJointTransformerBlock(
AuraFlowTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,

View File

@@ -22,13 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention import Attention
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -22,12 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention import Attention
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -19,15 +19,348 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import BasicTransformerBlock
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..attention import Attention, GatedSelfAttentionDense
from ..embeddings import PatchEmbed, SinusoidalPositionalEmbedding
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
class DiTTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single":
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class DiTTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model as introduced in DiT (https://huggingface.co/papers/2212.09748).
@@ -121,7 +454,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
DiTTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,

View File

@@ -15,7 +15,7 @@ from typing import Optional
from torch import nn
from ..modeling_outputs import Transformer2DModelOutput
from .modeling_common import Transformer2DModelOutput
from .transformer_2d import Transformer2DModel

View File

@@ -19,16 +19,15 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
from ..embeddings import (
HunyuanCombinedTimestepTextSizeStyleEmbedding,
PatchEmbed,
PixArtAlphaTextProjection,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -12,18 +12,359 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ..attention import BasicTransformerBlock
from ...utils import logging
from ..attention import Attention, GatedSelfAttentionDense
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..embeddings import (
PatchEmbed,
PixArtAlphaTextProjection,
SinusoidalPositionalEmbedding,
get_1d_sincos_pos_embed_from_grid,
)
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
class LatteTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single":
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
@@ -110,7 +451,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
# 2. Define spatial transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
LatteTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
@@ -130,7 +471,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
# 3. Define temporal transformers blocks
self.temporal_transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
LatteTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,

View File

@@ -19,15 +19,15 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention, LuminaAttnProcessor2_0
from ..embeddings import (
LuminaCombinedTimestepCaptionEmbedding,
LuminaPatchEmbed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
from .modeling_common import Transformer2DModelOutput
from .transformer_lumina2 import LuminaFeedForward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -0,0 +1,114 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from ...utils import BaseOutput, deprecate
from ..activations import GEGLU, GELU, ApproximateGELU, LinearActivation, SwiGLU
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: "torch.Tensor" # noqa: F821
class FeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
elif activation_fn == "linear-silu":
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
self.net = nn.ModuleList([])
# project in
self.net.append(act_fn)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states

View File

@@ -18,17 +18,349 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import BasicTransformerBlock
from ..attention import GatedSelfAttentionDense
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
class PixArtTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single":
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
r"""
A 2D Transformer model as introduced in PixArt family of models (https://huggingface.co/papers/2310.00426,
@@ -151,7 +483,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
PixArtTransformerBlock(
self.inner_dim,
self.config.num_attention_heads,
self.config.attention_head_dim,

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.nn.functional as F
@@ -7,8 +7,8 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput
from ..attention import BasicTransformerBlock
from ...utils import BaseOutput, logging
from ..attention import Attention, GatedSelfAttentionDense
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -16,8 +16,345 @@ from ..attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..embeddings import SinusoidalPositionalEmbedding, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero
from .modeling_common import FeedForward, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.models.transformers.transformer_2d.BasicTransformerBlock
class PriorTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single":
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
@dataclass
@@ -133,7 +470,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
PriorTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,

View File

@@ -27,9 +27,9 @@ from ..attention_processor import (
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from .modeling_common import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -23,10 +23,10 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
from ..modeling_utils import ModelMixin
from ..transformers.transformer_2d import Transformer2DModelOutput
from .modeling_common import FeedForward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -19,16 +19,352 @@ from torch import nn
from ...configuration_utils import LegacyConfigMixin, register_to_config
from ...utils import deprecate, logging
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..attention import Attention, GatedSelfAttentionDense
from ..embeddings import (
ImagePositionalEmbeddings,
PatchEmbed,
PixArtAlphaTextProjection,
SinusoidalPositionalEmbedding,
)
from ..modeling_utils import LegacyModelMixin
from ..normalization import AdaLayerNormSingle
from ..normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormSingle, AdaLayerNormZero
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
ada_norm_bias: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if norm_type == "ada_norm":
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_zero":
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm1 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
if norm_type == "ada_norm":
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif norm_type == "ada_norm_continuous":
self.norm2 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"rms_norm",
)
else:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
else:
if norm_type == "ada_norm_single":
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if norm_type == "ada_norm_continuous":
self.norm3 = AdaLayerNormContinuous(
dim,
ada_norm_continous_conditioning_embedding_dim,
norm_elementwise_affine,
norm_eps,
ada_norm_bias,
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if norm_type == "ada_norm_single":
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.norm_type == "ada_norm_zero":
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm1(hidden_states)
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.norm_type == "ada_norm_zero":
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.norm_type == "ada_norm_single":
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "ada_norm":
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
norm_hidden_states = self.norm2(hidden_states)
elif self.norm_type == "ada_norm_single":
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
elif self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
# i2vgen doesn't have this norm 🤷‍♂️
if self.norm_type == "ada_norm_continuous":
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif not self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm3(hidden_states)
if self.norm_type == "ada_norm_zero":
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.norm_type == "ada_norm_single":
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.norm_type == "ada_norm_zero":
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.norm_type == "ada_norm_single":
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Transformer2DModelOutput(Transformer2DModelOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.modeling_outputs import Transformer2DModelOutput`, instead."

View File

@@ -22,13 +22,12 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__)

View File

@@ -10,13 +10,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention import AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, apply_rotary_emb, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -24,12 +24,12 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
from .transformer_flux import FluxAttention, FluxAttnProcessor

View File

@@ -20,12 +20,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -22,13 +22,12 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -22,12 +22,11 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
if is_torchvision_available():

View File

@@ -22,11 +22,11 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -25,7 +25,7 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
@@ -34,9 +34,9 @@ from ..embeddings import (
apply_rotary_emb,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -6,8 +6,8 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.modeling_common import Transformer2DModelOutput
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention

View File

@@ -23,7 +23,6 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
@@ -33,9 +32,9 @@ from ..embeddings import (
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -23,9 +23,9 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from .modeling_common import Transformer2DModelOutput
from .transformer_hunyuan_video import (
HunyuanVideoConditionEmbedding,
HunyuanVideoPatchEmbed,

View File

@@ -25,13 +25,13 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -23,17 +23,66 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import LuminaFeedForward
from ..activations import FP32SiLU
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
from .modeling_common import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LuminaFeedForward(nn.Module):
r"""
A feed-forward layer.
Parameters:
hidden_size (`int`):
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
hidden representations.
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
of this value.
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
dimension. Defaults to None.
"""
def __init__(
self,
dim: int,
inner_dim: int,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.linear_2 = nn.Linear(
inner_dim,
dim,
bias=False,
)
self.linear_3 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.silu = FP32SiLU()
def forward(self, x):
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,

View File

@@ -23,13 +23,12 @@ from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -23,9 +23,9 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, RMSNorm
from .modeling_common import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -26,14 +26,14 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -20,7 +20,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward, JointTransformerBlock
from ..attention_processor import (
Attention,
AttentionProcessor,
@@ -28,14 +27,184 @@ from ..attention_processor import (
JointAttnProcessor2_0,
)
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
from .modeling_common import FeedForward, Transformer2DModelOutput, _chunked_feed_forward
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SD3TransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://huggingface.co/papers/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
context_pre_only: bool = False,
qk_norm: Optional[str] = None,
use_dual_attention: bool = False,
):
super().__init__()
self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim)
else:
self.norm1 = AdaLayerNormZero(dim)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
)
elif context_norm_type == "ada_norm_zero":
self.norm1_context = AdaLayerNormZero(dim)
else:
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
if use_dual_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=1e-6,
)
else:
self.attn2 = None
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
if not context_pre_only:
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
else:
self.norm2_context = None
self.ff_context = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
else:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
if self.context_pre_only:
encoder_hidden_states = None
else:
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
context_ff_output = _chunked_feed_forward(
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
)
else:
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
def __init__(
@@ -155,7 +324,7 @@ class SD3Transformer2DModel(
self.transformer_blocks = nn.ModuleList(
[
JointTransformerBlock(
SD3TransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,

View File

@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import (
@@ -32,9 +32,9 @@ from ..embeddings import (
get_1d_rotary_pos_embed,
get_1d_sincos_pos_embed_from_grid,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin, get_parameter_dtype
from ..normalization import FP32LayerNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -19,10 +19,13 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import AlphaBlender
from .modeling_common import FeedForward, _chunked_feed_forward
from .transformer_2d import BasicTransformerBlock
@dataclass
@@ -38,6 +41,136 @@ class TransformerTemporalModelOutput(BaseOutput):
sample: torch.Tensor
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class TransformerTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.

View File

@@ -24,13 +24,13 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -21,11 +21,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .modeling_common import FeedForward, Transformer2DModelOutput
from .transformer_wan import (
WanAttention,
WanAttnProcessor,

View File

@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import logging
from ..activations import get_activation
from ..attention import Attention, FeedForward
from ..attention import Attention
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -33,6 +33,7 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformers.modeling_common import FeedForward
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
UNetMidBlock3DCrossAttn,

View File

@@ -24,7 +24,6 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -41,7 +40,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
from ..transformers.transformer_2d import BasicTransformerBlock, Transformer2DModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel

View File

@@ -22,10 +22,10 @@ from torch.utils.checkpoint import checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -34,6 +34,79 @@ from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D
from ..transformers.transformer_2d import BasicTransformerBlock
class SkipFFTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
kv_input_dim: int,
kv_input_dim_proj_use_bias: bool,
dropout=0.0,
cross_attention_dim: int = None,
attention_bias: bool = False,
attention_out_bias: bool = True,
):
super().__init__()
if kv_input_dim != dim:
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
else:
self.kv_mapper = None
self.norm1 = RMSNorm(dim, 1e-06)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim,
out_bias=attention_out_bias,
)
self.norm2 = RMSNorm(dim, 1e-06)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
out_bias=attention_out_bias,
)
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
if self.kv_mapper is not None:
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
return hidden_states
class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):

View File

@@ -1,147 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXPipeline,
LTXVideoTransformer3DModel,
)
from ..testing_utils import floats_tensor, require_peft_backend
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
transformer_kwargs = {
"in_channels": 8,
"out_channels": 8,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 4,
"attention_head_dim": 8,
"cross_attention_dim": 32,
"num_layers": 1,
"caption_channels": 32,
}
transformer_cls = LTXVideoTransformer3DModel
vae_kwargs = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (8, 8, 8, 8),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, False, False),
"decoder_spatio_temporal_scaling": (True, True, False, False),
"decoder_inject_noise": (False, False, False, False, False),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
}
vae_cls = AutoencoderKLLTXVideo
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
text_encoder_target_modules = ["q", "k", "v", "o"]
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 16
num_channels = 8
num_frames = 9
num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
latent_height = 8
latent_width = 8
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_latent_frames, num_channels, latent_height, latent_width))
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "dance monkey",
"num_frames": num_frames,
"num_inference_steps": 4,
"guidance_scale": 6.0,
"height": 32,
"width": 32,
"max_sequence_length": sequence_length,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
@unittest.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass