Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 05:24:20 +08:00)

Compare commits: 19 commits, v0.35.1-pa...add-attn-m
| Author | SHA1 | Date |
|---|---|---|
| | ad4e8be19a | |
| | be84828840 | |
| | a7462302cd | |
| | 9b6a062adf | |
| | 0ed58e4ec6 | |
| | 200e4ac462 | |
| | 94ae28edea | |
| | 37de8e790c | |
| | 1b4067c0d1 | |
| | b68673d8a7 | |
| | 747f7f6d07 | |
| | a923a73a17 | |
| | b67fcf2221 | |
| | 93d8799604 | |
| | 858d7c5cd5 | |
| | 6f2883396d | |
| | 36c68ad6bd | |
| | 8249ac9100 | |
| | cebbe8960a | |
setup.py (2 changed lines)

@@ -136,7 +136,7 @@ _deps = [
     "requests",
     "tensorboard",
     "tiktoken>=0.7.0",
-    "torch>=1.4",
+    "torch>=2.0",
     "torchvision",
     "transformers>=4.41.2",
     "urllib3<=2.0.0",
@@ -43,7 +43,7 @@ deps = {
     "requests": "requests",
     "tensorboard": "tensorboard",
     "tiktoken": "tiktoken>=0.7.0",
-    "torch": "torch>=1.4",
+    "torch": "torch>=2.0",
     "torchvision": "torchvision",
     "transformers": "transformers>=4.41.2",
     "urllib3": "urllib3<=2.0.0",
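The bump from `torch>=1.4` to `torch>=2.0` lines up with the new SDPA attention processors added further down, which all refuse to instantiate without `F.scaled_dot_product_attention`. A minimal sketch of that runtime guard (the check mirrors the processors in this diff; the error message wording here is only illustrative):

```python
import torch
import torch.nn.functional as F

# The new *AttnProcessorSDPA classes in this branch guard on SDPA being available,
# which is only guaranteed on torch>=2.0, hence the raised dependency floor.
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError(f"torch {torch.__version__} lacks F.scaled_dot_product_attention; install torch>=2.0.")
```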
@@ -27,6 +27,7 @@ _import_structure = {}
 if is_torch_available():
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["auto_model"] = ["AutoModel"]
+    _import_structure["attention_modules"] = ["FluxAttention", "SanaAttention", "SD3Attention"]
     _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
     _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
     _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
@@ -106,6 +107,7 @@ if is_flax_available():
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .adapter import MultiAdapter, T2IAdapter
+        from .attention_modules import FluxAttention, SanaAttention, SD3Attention
         from .auto_model import AutoModel
         from .autoencoders import (
             AsymmetricAutoencoderKL,
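With the lazy-import entries above, the new attention classes become importable from `diffusers.models`. A hedged usage sketch, assuming this branch is installed (only the import path is taken from the diff; the constructors' signatures are not shown in this compare view):

```python
# Wired up by the _import_structure entries added above in src/diffusers/models/__init__.py.
from diffusers.models import FluxAttention, SanaAttention, SD3Attention

print(FluxAttention, SanaAttention, SD3Attention)
```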
(Two file diffs suppressed because they are too large.)
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import importlib
 import os
 from typing import Optional, Union

-from huggingface_hub.utils import validate_hf_hub_args
+from huggingface_hub.utils import EntryNotFoundError, validate_hf_hub_args

 from ..configuration_utils import ConfigMixin
+from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates


 class AutoModel(ConfigMixin):
@@ -153,17 +153,39 @@ class AutoModel(ConfigMixin):
             "token": token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
         }

-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
-
-        library = importlib.import_module("diffusers")
-
-        model_cls = getattr(library, orig_class_name, None)
-        if model_cls is None:
-            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
+        library = None
+        orig_class_name = None
+        from diffusers import pipelines
+
+        # Always attempt to fetch model_index.json first
+        try:
+            cls.config_name = "model_index.json"
+            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+
+            if subfolder is not None and subfolder in config:
+                library, orig_class_name = config[subfolder]
+
+        except (OSError, EntryNotFoundError) as e:
+            logger.debug(e)
+
+        # Unable to load from model_index.json so fallback to loading from config
+        if library is None and orig_class_name is None:
+            cls.config_name = "config.json"
+            load_config_kwargs.update({"subfolder": subfolder})
+
+            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+            orig_class_name = config["_class_name"]
+            library = "diffusers"
+
+        model_cls, _ = get_class_obj_and_candidates(
+            library_name=library,
+            class_name=orig_class_name,
+            importable_classes=ALL_IMPORTABLE_CLASSES,
+            pipelines=pipelines,
+            is_pipeline_module=hasattr(pipelines, library),
+        )

         kwargs = {**load_config_kwargs, **kwargs}
         return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
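Taken together, `AutoModel.from_pretrained` now resolves the class in two steps: it first tries the repo-level `model_index.json` (using the `subfolder` entry to obtain both the library and class name) and only then falls back to the component's own `config.json`. A hedged usage sketch; the repo id below is just an example and not taken from the diff:

```python
from diffusers import AutoModel

# Resolved via model_index.json when the repo has one (the "unet" entry names
# the library and class); otherwise falls back to reading unet/config.json.
unet = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
print(type(unet).__name__)
```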
@@ -62,6 +62,98 @@ class ResBlock(nn.Module):
         return hidden_states + residual


+class SanaMultiscaleAttentionProjection(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        kernel_size: int,
+    ) -> None:
+        super().__init__()
+
+        channels = 3 * in_channels
+        self.proj_in = nn.Conv2d(
+            channels,
+            channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=channels,
+            bias=False,
+        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states
+
+
+class SanaMultiscaleLinearAttention(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_attention_heads: Optional[int] = None,
+        attention_head_dim: int = 8,
+        mult: float = 1.0,
+        norm_type: str = "batch_norm",
+        kernel_sizes: Tuple[int, ...] = (5,),
+        eps: float = 1e-15,
+        residual_connection: bool = False,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from ..normalization import get_normalization
+
+        self.eps = eps
+        self.attention_head_dim = attention_head_dim
+        self.norm_type = norm_type
+        self.residual_connection = residual_connection
+
+        num_attention_heads = (
+            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
+        )
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+        self.to_qkv_multiscale = nn.ModuleList()
+        for kernel_size in kernel_sizes:
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )
+
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+        self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+        self.processor = SanaMultiscaleAttnProcessorSDPA()
+
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # Adds padding
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states
+
+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
+        return hidden_states
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+
+
 class EfficientViTBlock(nn.Module):
     def __init__(
         self,
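The `apply_linear_attention` helper above uses a padding trick: appending a row of ones to `value` lets the same two matmuls compute both the attention numerator and its normalizer, which the final division then separates. A small self-contained check of that identity (shapes, the ReLU feature map, and the epsilon are illustrative, not taken from the diff):

```python
import torch

# d = head_dim (channels), n = number of tokens; columns are tokens, matching the module layout.
d, n = 4, 6
query = torch.relu(torch.randn(d, n))
key = torch.relu(torch.randn(d, n))
value = torch.randn(d, n)

# Padding value with a row of ones makes the last row of (v_pad @ k^T) @ q accumulate
# the normalizer sum_j k_j . q_i, so numerator and denominator come out of one product.
value_pad = torch.cat([value, torch.ones(1, n)], dim=0)
out = (value_pad @ key.T) @ query
trick = out[:-1] / (out[-1:] + 1e-15)

# Explicit reference: per-token weights k_j . q_i, normalized over j.
weights = key.T @ query                                              # (n, n)
reference = (value @ weights) / (weights.sum(dim=0, keepdim=True) + 1e-15)

assert torch.allclose(trick, reference, atol=1e-5)
```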
@@ -21,7 +21,8 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils.accelerate_utils import apply_forward_hook
-from ..attention_processor import Attention, SpatialNorm
+from ..attention import Attention
+from ..attention_processor import SpatialNorm
 from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
 from ..downsampling import Downsample2D
 from ..modeling_outputs import AutoencoderKLOutput
@@ -24,7 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..attention_processor import Attention
+from ..attention import Attention
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .vae import DecoderOutput, DiagonalGaussianDistribution
@@ -23,7 +23,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
+from ..attention import Attention
+from ..attention_processor import MochiVaeAttnProcessor2_0
 from ..modeling_outputs import AutoencoderKLOutput
 from ..modeling_utils import ModelMixin
 from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
@@ -22,11 +22,12 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import JointTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
+from ..attention import Attention
+from ..attention_processor import AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from ..transformers.modeling_common import JointTransformerBlock
 from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
 from .controlnet import BaseOutput, zero_module
@@ -21,7 +21,7 @@ from torch import nn

 from ..utils import deprecate
 from .activations import FP32SiLU, get_activation
-from .attention_processor import Attention
+from .attention import Attention


 def get_timestep_embedding(
@@ -23,11 +23,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import Attention, AttentionMixin, AttentionModuleMixin
 from ..attention_processor import (
-    Attention,
     AttentionProcessor,
+    AttnProcessorMixin,
     AuraFlowAttnProcessor2_0,
     FusedAuraFlowAttnProcessor2_0,
 )
 from ..embeddings import TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
@@ -135,6 +134,98 @@ class AuraFlowPreFinalBlock(nn.Module):
         return x


+class AuraFlowAttnProcessorSDPA(AttnProcessorMixin):
+    """Attention processor used typically in processing Aura Flow."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
+            raise ImportError(
+                "AuraFlowAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
+            )
+
+    def __call__(
+        self,
+        attn: AttentionModuleMixin,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        batch_size = hidden_states.shape[0]
+        query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states)
+
+        # `context` projections.
+        if encoder_projections is not None:
+            encoder_hidden_states_query_proj, encoder_hidden_states_key_proj, encoder_hidden_states_value_proj = (
+                encoder_projections
+            )
+
+        # Reshape.
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+        key = key.view(batch_size, -1, attn.heads, head_dim)
+        value = value.view(batch_size, -1, attn.heads, head_dim)
+
+        # Apply QK norm.
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Concatenate the projections.
+        if encoder_projections is not None:
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            )
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
+
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # Attention.
+        hidden_states = self.attention_fn(query, key, value, scale=attn.scale)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+            )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if encoder_hidden_states is not None:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
+class AuraFlowAttention(Attention):
+    default_processor_cls = AuraFlowAttnProcessorSDPA
+    _available_processors = [AuraFlowAttnProcessorSDPA]
+
+
 @maybe_allow_in_graph
 class AuraFlowSingleTransformerBlock(nn.Module):
     """Similar to `AuraFlowJointTransformerBlock` with a single DiT instead of an MMDiT."""
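In `AuraFlowAttnProcessorSDPA.__call__`, the context tokens are concatenated in front of the image tokens before attention, and the joint output is then sliced back so that the image stream comes first in the returned tuple. A tiny shape-only sketch of that round trip (tensor sizes and names are illustrative):

```python
import torch

batch, text_len, img_len, dim = 2, 7, 16, 8
encoder_hidden_states = torch.randn(batch, text_len, dim)  # context/text tokens
hidden_states = torch.randn(batch, img_len, dim)           # image tokens

# Context first, as the processor concatenates before attention.
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# After attention the processor slices the joint sequence apart again:
image_out, context_out = joint[:, text_len:], joint[:, :text_len]
assert image_out.shape == hidden_states.shape and context_out.shape == encoder_hidden_states.shape
```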
@@ -145,7 +236,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm1 = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

         processor = AuraFlowAttnProcessor2_0()
-        self.attn = Attention(
+        self.attn = AuraFlowAttention(
             query_dim=dim,
             cross_attention_dim=None,
             dim_head=attention_head_dim,
@@ -208,7 +299,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.norm1_context = AdaLayerNormZero(dim, bias=False, norm_type="fp32_layer_norm")

         processor = AuraFlowAttnProcessor2_0()
-        self.attn = Attention(
+        self.attn = AuraFlowAttention(
             query_dim=dim,
             cross_attention_dim=None,
             added_kv_proj_dim=dim,
@@ -267,7 +358,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
     r"""
     A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -357,105 +448,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedAuraFlowAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedAuraFlowAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+    # Using methods from AttentionMixin

     def forward(
         self,
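The per-model `attn_processors` / `set_attn_processor` boilerplate removed here is replaced by methods inherited from `AttentionMixin`. The actual mixin lives in the new `modeling_common.py`, whose diff is suppressed in this view, so the following is only a hedged sketch of what such a mixin plausibly looks like, reconstructed from the code being deleted:

```python
from typing import Dict, Union

import torch


class AttentionMixin:
    """Sketch only: meant to be mixed into an nn.Module, so self.named_children() exists."""

    @property
    def attn_processors(self) -> Dict[str, object]:
        # Recursively collect every submodule that exposes get_processor().
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[object, Dict[str, object]]):
        # Accept either one processor for all layers or a {path: processor} dict.
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(f"Expected {count} processors, got {len(processor)}.")

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
```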
@@ -22,18 +22,92 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_processor import Attention, AttnProcessorMixin
 from ..cache_utils import CacheMixin
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class CogVideoXAttnProcessorSDPA(AttnProcessorMixin):
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    compatible_backends = ["cuda", "cpu", "xpu"]
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "CogVideoXAttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: AttentionModuleMixin,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query, key, value, _ = self.get_projections(attn, hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        hidden_states = self.attention_fn(query, key, value, attention_mask=attention_mask)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class CogVideoXAttention(Attention):
+    default_processor_cls = CogVideoXAttnProcessorSDPA
+    _available_processors = [CogVideoXAttnProcessorSDPA]
+
+
 @maybe_allow_in_graph
 class CogVideoXBlock(nn.Module):
     r"""
@@ -92,7 +166,7 @@ class CogVideoXBlock(nn.Module):
         # 1. Self Attention
         self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

-        self.attn1 = Attention(
+        self.attn1 = CogVideoXAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
@@ -100,7 +174,6 @@ class CogVideoXBlock(nn.Module):
             eps=1e-6,
             bias=attention_bias,
             out_bias=attention_out_bias,
-            processor=CogVideoXAttnProcessor2_0(),
         )

         # 2. Feed Forward
@@ -157,7 +230,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states


-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin, AttentionMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
@@ -331,105 +404,9 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
+    # Using inherited methods from AttentionMixin

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
+    # Using inherited methods from AttentionMixin

     def forward(
         self,
@@ -22,12 +22,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
-from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
+from ..attention import Attention
+from ..attention_processor import CogVideoXAttnProcessor2_0, CogVideoXAttnProcessorSDPA
 from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -228,6 +229,11 @@ class PerceiverCrossAttention(nn.Module):
         return self.to_out(out)


+class ConsisIDAttention(Attention):
+    default_processor_cls = CogVideoXAttnProcessorSDPA
+    _available_processors = [CogVideoXAttnProcessorSDPA]
+
+
 @maybe_allow_in_graph
 class ConsisIDBlock(nn.Module):
     r"""
@@ -286,7 +292,7 @@ class ConsisIDBlock(nn.Module):
         # 1. Self Attention
         self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

-        self.attn1 = Attention(
+        self.attn1 = ConsisIDAttention(
             query_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
@@ -348,7 +354,7 @@ class ConsisIDBlock(nn.Module):
         return hidden_states, encoder_hidden_states


-class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, AttentionMixin):
     """
     A Transformer model for video-like data in [ConsisID](https://github.com/PKU-YuanGroup/ConsisID).
@@ -620,65 +626,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             ]
         )

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
+    # Using methods from AttentionMixin

     def forward(
         self,
@@ -19,16 +19,17 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
+from .modeling_common import BasicTransformerBlock


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class DiTTransformer2DModel(ModelMixin, ConfigMixin):
+class DiTTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     A 2D Transformer model as introduced in DiT (https://arxiv.org/abs/2212.09748).
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Optional, Union
+from typing import Optional

 import torch
 from torch import nn
@@ -19,8 +19,8 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0
+from ..attention import Attention, AttentionMixin, AttnProcessorMixin
+from ..attention_processor import HunyuanAttnProcessor2_0, HunyuanAttnProcessorSDPA
 from ..embeddings import (
     HunyuanCombinedTimestepTextSizeStyleEmbedding,
     PatchEmbed,
@@ -29,6 +29,7 @@ from ..embeddings import (
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, FP32LayerNorm
+from .modeling_common import FeedForward


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,6 +56,98 @@ class AdaLayerNormShift(nn.Module):
         return x


+class HunyuanAttnProcessorSDPA(AttnProcessorMixin):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+    used in the HunyuanDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: AttentionModuleMixin,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query, key, value = self.get_projections(attn, hidden_states, encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb)
+            if not attn.is_cross_attention:
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = self.attention_fn(query, key, value, attn_mask=attention_mask)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class HunyuanDiTAttention(Attention):
+    default_processor_cls = HunyuanAttnProcessorSDPA
+    _available_processors = [HunyuanAttnProcessorSDPA]
+
+
 @maybe_allow_in_graph
 class HunyuanDiTBlock(nn.Module):
     r"""
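When the Hunyuan processor receives a 4D feature map it flattens the spatial grid into a token axis before attention and restores it afterwards. A small standalone round-trip check of that reshape (tensor sizes are illustrative):

```python
import torch

batch_size, channel, height, width = 2, 8, 4, 4
x = torch.randn(batch_size, channel, height, width)

# (B, C, H, W) -> (B, H*W, C): one token per spatial location, as done before attention.
tokens = x.view(batch_size, channel, height * width).transpose(1, 2)

# (B, H*W, C) -> (B, C, H, W): the inverse applied to the attention output.
restored = tokens.transpose(-1, -2).reshape(batch_size, channel, height, width)

assert torch.equal(x, restored)
```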
@@ -110,7 +203,7 @@ class HunyuanDiTBlock(nn.Module):
         # 1. Self-Attn
         self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-        self.attn1 = Attention(
+        self.attn1 = HunyuanDiTAttention(
             query_dim=dim,
             cross_attention_dim=None,
             dim_head=dim // num_attention_heads,
@@ -118,7 +211,6 @@ class HunyuanDiTBlock(nn.Module):
             qk_norm="layer_norm" if qk_norm else None,
             eps=1e-6,
             bias=True,
-            processor=HunyuanAttnProcessor2_0(),
         )

         # 2. Cross-Attn
@@ -200,7 +292,7 @@ class HunyuanDiTBlock(nn.Module):
         return hidden_states


-class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
+class HunyuanDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     """
     HunYuanDiT: Diffusion model with a Transformer backbone.
@@ -318,105 +410,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
+    # Using methods from AttentionMixin

     def set_default_attn_processor(self):
         """
@@ -19,15 +19,16 @@ from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
-from ..attention import BasicTransformerBlock
+from ..attention import AttentionMixin
 from ..cache_utils import CacheMixin
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
+from .modeling_common import BasicTransformerBlock


-class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
+class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin, AttentionMixin):
     _supports_gradient_checkpointing = True

     """
@@ -20,7 +20,7 @@ import torch.nn as nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ..attention import LuminaFeedForward
-from ..attention_processor import Attention, LuminaAttnProcessor2_0
+from ..attention_processor import Attention, AttentionMixin, AttnProcessorMixin
 from ..embeddings import (
     LuminaCombinedTimestepCaptionEmbedding,
     LuminaPatchEmbed,
@@ -33,6 +33,101 @@ from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


+class LuminaAttnProcessorSDPA(AttnProcessorMixin):
+    compatible_backends = ["cuda", "cpu", "xpu"]
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessorSDPA requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        query_rotary_emb: Optional[torch.Tensor] = None,
+        key_rotary_emb: Optional[torch.Tensor] = None,
+        base_sequence_length: Optional[int] = None,
+    ) -> torch.Tensor:
+        from .embeddings import apply_rotary_emb
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = hidden_states.shape
+
+        query, key, value, _ = self.get_projections(attn, hidden_states, encoder_hidden_states)
+
+        query_dim = query.shape[-1]
+        inner_dim = key.shape[-1]
+        head_dim = query_dim // attn.heads
+        dtype = query.dtype
+
+        # Get key-value heads
+        kv_heads = inner_dim // head_dim
+
+        # Apply Query-Key Norm if needed
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.view(batch_size, -1, attn.heads, head_dim)
+
+        key = key.view(batch_size, -1, kv_heads, head_dim)
+        value = value.view(batch_size, -1, kv_heads, head_dim)
+
+        # Apply RoPE if needed
+        if query_rotary_emb is not None:
+            query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
+        if key_rotary_emb is not None:
+            key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
+
+        query, key = query.to(dtype), key.to(dtype)
+
+        # Apply proportional attention if true
+        if key_rotary_emb is None:
+            softmax_scale = None
+        else:
+            if base_sequence_length is not None:
+                softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+            else:
+                softmax_scale = attn.scale
+
+        # perform Grouped-qurey Attention (GQA)
+        n_rep = attn.heads // kv_heads
+        if n_rep >= 1:
+            key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+            value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+        # scaled_dot_product_attention expects attention_mask shape to be
+        # (batch, heads, source_length, target_length)
+        attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+        attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, scale=softmax_scale
+        )
+        hidden_states = hidden_states.transpose(1, 2).to(dtype)
+
+        return hidden_states
+
+
+class LuminaNextAttention(Attention):
+    default_processor_cls = LuminaAttnProcessorSDPA
+    _available_processors = [LuminaAttnProcessorSDPA]
+
+
 class LuminaNextDiTBlock(nn.Module):
     """
     A LuminaNextDiTBlock for LuminaNextDiT2DModel.
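The Lumina processor implements grouped-query attention by repeating each key/value head `n_rep` times so K and V match the number of query heads. The `unsqueeze(3).repeat(...).flatten(2, 3)` idiom above is equivalent to a `repeat_interleave` over the head axis; a small check with illustrative shapes:

```python
import torch

batch, seq, kv_heads, head_dim, n_rep = 2, 5, 2, 4, 3  # query heads = kv_heads * n_rep
key = torch.randn(batch, seq, kv_heads, head_dim)

# Idiom used in LuminaAttnProcessorSDPA: duplicate each kv head n_rep times.
expanded = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

# Same result expressed with repeat_interleave over the head dimension.
reference = key.repeat_interleave(n_rep, dim=2)

assert expanded.shape == (batch, seq, kv_heads * n_rep, head_dim)
assert torch.equal(expanded, reference)
```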
@@ -68,7 +163,7 @@ class LuminaNextDiTBlock(nn.Module):
         self.gate = nn.Parameter(torch.zeros([num_attention_heads]))

         # Self-attention
-        self.attn1 = Attention(
+        self.attn1 = LuminaNextAttention(
             query_dim=dim,
             cross_attention_dim=None,
             dim_head=dim // num_attention_heads,
@@ -78,12 +173,11 @@ class LuminaNextDiTBlock(nn.Module):
             eps=1e-5,
             bias=False,
             out_bias=False,
-            processor=LuminaAttnProcessor2_0(),
         )
         self.attn1.to_out = nn.Identity()

         # Cross-attention
-        self.attn2 = Attention(
+        self.attn2 = LuminaNextAttention(
             query_dim=dim,
             cross_attention_dim=cross_attention_dim,
             dim_head=dim // num_attention_heads,
@@ -93,7 +187,6 @@ class LuminaNextDiTBlock(nn.Module):
             eps=1e-5,
             bias=False,
             out_bias=False,
-            processor=LuminaAttnProcessor2_0(),
         )

         self.feed_forward = LuminaFeedForward(
@@ -175,7 +268,7 @@ class LuminaNextDiTBlock(nn.Module):
         return hidden_states


-class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
+class LuminaNextDiT2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     """
     LuminaNextDiT: Diffusion model with a Transformer backbone.
src/diffusers/models/transformers/modeling_common.py (new file, 1251 lines)
File diff suppressed because it is too large.
@@ -11,25 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional

 import torch
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
-from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
+from ..attention import AttentionMixin
+from ..attention_processor import AttnProcessor
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
+from .modeling_common import BasicTransformerBlock


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
+class PixArtTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
     r"""
     A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
     https://arxiv.org/abs/2403.04692).
@@ -184,65 +185,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
             in_features=self.config.caption_channels, hidden_size=self.inner_dim
         )

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
+    # Using inherited method from AttentionMixin

     def set_default_attn_processor(self):
         """
@@ -252,45 +195,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -8,16 +8,16 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -33,7 +33,7 @@ class PriorTransformerOutput(BaseOutput):
|
||||
predicted_image_embedding: torch.Tensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin, AttentionMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
@@ -166,65 +166,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
|
||||
@@ -21,9 +21,9 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import Attention, AttentionMixin
from ..attention_processor import (
    Attention,
    AttentionProcessor,
    AttentionModuleMixin,
    SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
@@ -35,6 +35,104 @@ from ..normalization import AdaLayerNormSingle, RMSNorm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class SanaAttention(nn.Module, AttentionModuleMixin):
    """
    Attention implementation specialized for Sana models.

    This module implements lightweight multi-scale linear attention as used in Sana.
    """
    # Set Sana-specific processor classes
    default_processor_class = SanaLinearAttnProcessor2_0

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_attention_heads: Optional[int] = None,
        attention_head_dim: int = 8,
        mult: float = 1.0,
        norm_type: str = "batch_norm",
        kernel_sizes: Tuple[int, ...] = (5,),
        eps: float = 1e-15,
        residual_connection: bool = False,
    ):
        super().__init__()

        # Core parameters
        self.eps = eps
        self.attention_head_dim = attention_head_dim
        self.norm_type = norm_type
        self.residual_connection = residual_connection

        # Calculate dimensions
        num_attention_heads = (
            int(in_channels // attention_head_dim * mult) if num_attention_heads is None else num_attention_heads
        )
        inner_dim = num_attention_heads * attention_head_dim
        self.inner_dim = inner_dim
        self.heads = num_attention_heads

        # Query, key, value projections
        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)

        # Multi-scale attention
        self.to_qkv_multiscale = nn.ModuleList()
        for kernel_size in kernel_sizes:
            self.to_qkv_multiscale.append(
                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
            )

        # Output layers
        self.nonlinearity = nn.ReLU()
        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)

        # Get normalization based on type
        if norm_type == "batch_norm":
            self.norm_out = nn.BatchNorm1d(out_channels)
        elif norm_type == "layer_norm":
            self.norm_out = nn.LayerNorm(out_channels)
        elif norm_type == "group_norm":
            self.norm_out = nn.GroupNorm(32, out_channels)
        elif norm_type == "instance_norm":
            self.norm_out = nn.InstanceNorm1d(out_channels)
        else:
            self.norm_out = nn.Identity()

        # Set processor
        self.processor = self.default_processor_class()

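A short, hypothetical usage sketch of the `SanaAttention` module defined above. The channel width is illustrative and not taken from the diff, and the expected input layout depends on the processor that is set, so only construction and attribute inspection are shown:

# Hypothetical sketch; constructor arguments mirror the SanaAttention definition above.
attn = SanaAttention(
    in_channels=1152,          # example width, not from the PR
    out_channels=1152,
    attention_head_dim=32,
    norm_type="batch_norm",
    kernel_sizes=(5,),
)

print(type(attn.processor).__name__)   # SanaLinearAttnProcessor2_0 by default
print(attn.inner_dim, attn.heads)      # derived from in_channels, attention_head_dim and mult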
class SanaMultiscaleAttentionProjection(nn.Module):
    """Projection layer for Sana multi-scale attention."""

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        kernel_size: int,
    ) -> None:
        super().__init__()

        channels = 3 * in_channels
        self.proj_in = nn.Conv2d(
            channels,
            channels,
            kernel_size,
            padding=kernel_size // 2,
            groups=channels,
            bias=False,
        )
        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states


class GLUMBConv(nn.Module):
    def __init__(
        self,
@@ -289,7 +387,7 @@ class SanaTransformerBlock(nn.Module):
        return hidden_states


class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
    r"""
    A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.

@@ -414,65 +512,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -21,10 +21,8 @@ import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention import Attention, AttentionMixin, FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
StableAudioAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -187,7 +185,7 @@ class StableAudioDiTBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
class StableAudioDiTModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
"""
|
||||
The Diffusion Transformer model introduced in Stable Audio.
|
||||
|
||||
@@ -279,65 +277,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
# Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio
|
||||
def set_default_attn_processor(self):
|
||||
|
||||
@@ -19,11 +19,12 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import LegacyConfigMixin, register_to_config
|
||||
from ...utils import deprecate, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import LegacyModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import BasicTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -36,7 +37,7 @@ class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
class Transformer2DModel(LegacyModelMixin, LegacyConfigMixin, AttentionMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
|
||||
@@ -22,13 +22,13 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import AllegroAttnProcessor2_0, Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
from .modeling_common import FeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -13,16 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention import Attention, AttentionMixin, FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -130,7 +128,7 @@ class CogView3PlusTransformerBlock(nn.Module):
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
|
||||
Diffusion](https://huggingface.co/papers/2403.05121).
|
||||
@@ -229,65 +227,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
# Using methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -21,13 +21,13 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
from .modeling_common import FeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -22,11 +22,12 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
|
||||
from .modeling_common import FeedForward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -13,27 +13,24 @@
# limitations under the License.


from typing import Any, Dict, Optional, Tuple, Union
import math
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    FluxAttnProcessor2_0,
    FluxAttnProcessor2_0_NPU,
    FusedFluxAttnProcessor2_0,
)
from ...models.attention_processor import AttentionModuleMixin
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_torch_xla_version
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
@@ -42,6 +39,429 @@ from ..modeling_outputs import Transformer2DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class BaseFluxAttnProcessor:
|
||||
"""Base attention processor for Flux models with common functionality."""
|
||||
|
||||
compatible_backends = []
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def get_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
"""Public method to get projections based on whether we're using fused mode or not."""
|
||||
if self.is_fused and hasattr(attn, "to_qkv"):
|
||||
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
return self._get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
"""Get projections using standard separate projection matrices."""
|
||||
# Standard separate projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# Handle encoder projections if present
|
||||
encoder_projections = None
|
||||
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||
|
||||
return query, key, value, encoder_projections
|
||||
|
||||
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
"""Get projections using fused QKV projection matrices."""
|
||||
# Fused QKV projection
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
# Handle encoder projections if present
|
||||
encoder_projections = None
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||
|
||||
return query, key, value, encoder_projections
|
||||
|
||||
def _compute_attention(self, query, key, value, attention_mask=None):
|
||||
"""Computes the attention. Can be overridden by hardware-specific implementations."""
|
||||
return F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "FluxAttention",
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
query, key, value, encoder_projections = self.get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if encoder_projections is not None:
|
||||
encoder_query, encoder_key, encoder_value = encoder_projections
|
||||
encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
# Concatenate for joint attention
|
||||
query = torch.cat([encoder_query, query], dim=2)
|
||||
key = torch.cat([encoder_key, key], dim=2)
|
||||
value = torch.cat([encoder_value, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = self._compute_attention(query, key, value, attention_mask)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxAttnProcessorSDPA(BaseFluxAttnProcessor):
|
||||
compatible_backends = ["cuda", "xpu", "cpu"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class FluxAttnProcessorNPU(BaseFluxAttnProcessor):
|
||||
"""NPU-specific implementation of Flux attention processor."""
|
||||
|
||||
compatible_backends = ["npu"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not is_torch_npu_available():
|
||||
raise ImportError("FluxAttnProcessorNPU requires torch_npu, please install it.")
|
||||
import torch_npu
|
||||
|
||||
self.attn_fn = torch_npu.npu_fusion_attention
|
||||
|
||||
def _compute_attention(self, query, key, value, attention_mask=None):
|
||||
if query.dtype in (torch.float16, torch.bfloat16):
|
||||
# NPU-specific implementation
|
||||
return self.attn_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.shape[1], # number of heads
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]),
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
# Fall back to standard implementation for other dtypes
|
||||
return F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
|
||||
class FluxAttnProcessorXLA(BaseFluxAttnProcessor):
|
||||
"""XLA-specific implementation of Flux attention processor."""
|
||||
|
||||
compatible_backends = ["xla"]
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if not is_torch_xla_available():
|
||||
raise ImportError(
|
||||
"FluxAttnProcessorXLA requires torch_xla, please install it using `pip install torch_xla`"
|
||||
)
|
||||
if is_torch_xla_version("<", "2.3"):
|
||||
raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
|
||||
|
||||
from torch_xla.experimental.custom_kernel import flash_attention
|
||||
|
||||
self.attn_fn = flash_attention
|
||||
|
||||
def _compute_attention(self, query, key, value, attention_mask=None):
|
||||
query /= math.sqrt(query.shape[3])
|
||||
hidden_states = self.attn_fn(query, key, value, causal=False)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxIPAdapterAttnProcessorSDPA(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "FluxAttention",
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
ip_adapter_masks: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
hidden_states_query_proj = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = torch.zeros_like(hidden_states)
|
||||
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
current_ip_hidden_states = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
ip_attn_output += scale * current_ip_hidden_states
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
class FluxAttention(nn.Module, AttentionModuleMixin):
    _default_processor_cls = FluxAttnProcessorSDPA
    _available_processors = [
        FluxAttnProcessorSDPA,
        FluxAttnProcessorNPU,
        FluxAttnProcessorXLA,
        FluxIPAdapterAttnProcessorSDPA,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # Core parameters
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head**-0.5
        self.use_bias = bias

        # Query, Key, Value projections
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # RMSNorm for Flux models
        self.norm_q = RMSNorm(dim_head, eps=1e-6)
        self.norm_k = RMSNorm(dim_head, eps=1e-6)

        # Optional added key/value projections for joint attention
        self.added_kv_proj_dim = added_kv_proj_dim
        if added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)

            # Normalization for added projections
            self.norm_added_q = RMSNorm(dim_head, eps=1e-6)
            self.norm_added_k = RMSNorm(dim_head, eps=1e-6)
            self.added_proj_bias = bias

            # Output projection for context
            self.to_add_out = nn.Linear(self.inner_dim, query_dim, bias=bias)

        # Output projection and dropout
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, query_dim, bias=bias), nn.Dropout(dropout)])

        # Set processor (class attribute defined above is `_default_processor_cls`)
        self.set_processor(self._default_processor_cls())

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for Flux attention.

        Args:
            hidden_states: Input hidden states
            encoder_hidden_states: Optional encoder hidden states for cross-attention
            attention_mask: Optional attention mask
            image_rotary_emb: Optional rotary embeddings for image tokens

        Returns:
            Output hidden states, and optionally encoder hidden states for joint attention
        """
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            **kwargs,
        )

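With `FluxAttention` declaring `_default_processor_cls` and `_available_processors`, backend-specific processors can be chosen explicitly instead of relying on the deprecated NPU auto-default that the block below removes. A hedged sketch, assuming `set_processor` is provided by `AttentionModuleMixin` and using the availability helpers imported above; sizes are typical Flux values, not prescribed by this diff:

# Hypothetical sketch of picking a backend-specific processor for a FluxAttention module.
attn = FluxAttention(query_dim=3072, heads=24, dim_head=128, bias=True, added_kv_proj_dim=3072)

if is_torch_npu_available():
    attn.set_processor(FluxAttnProcessorNPU())       # torch_npu fused attention path
elif is_torch_xla_available():
    attn.set_processor(FluxAttnProcessorXLA())       # torch_xla flash attention path
else:
    attn.set_processor(FluxAttnProcessorSDPA())      # default SDPA path (cuda / xpu / cpu)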
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -53,27 +473,12 @@ class FluxSingleTransformerBlock(nn.Module):
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        if is_torch_npu_available():
            deprecation_message = (
                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                "should be set explicitly using the `set_attn_processor` method."
            )
            deprecate("npu_processor", "0.34.0", deprecation_message)
            processor = FluxAttnProcessor2_0_NPU()
        else:
            processor = FluxAttnProcessor2_0()

        self.attn = Attention(
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            dropout=0.0,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
@@ -113,18 +518,15 @@ class FluxTransformerBlock(nn.Module):
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        self.attn = Attention(
        # Use specialized FluxAttention instead of generic Attention
        self.attn = FluxAttention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            dropout=0.0,
            bias=True,
            processor=FluxAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
            added_kv_proj_dim=dim,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -191,7 +593,7 @@ class FluxTransformerBlock(nn.Module):


class FluxTransformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin, AttentionMixin
):
    """
    The Transformer model introduced in Flux.
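Since `FluxTransformer2DModel` now inherits `AttentionMixin`, processor discovery and replacement go through the shared mixin rather than the per-model copies deleted in the hunk below. A hedged user-side sketch; the checkpoint id and module path for the processor class are illustrative, and `attn_processors` / `set_attn_processor` are assumed to come from the mixin:

# Hypothetical sketch: inspecting and replacing processors through the inherited AttentionMixin API.
import torch
from diffusers import FluxTransformer2DModel
# Module path assumed from where this diff adds the class.
from diffusers.models.transformers.transformer_flux import FluxAttnProcessorSDPA

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Keyed by weight name, exactly like the removed per-model property returned.
print(len(transformer.attn_processors))

# Set one processor class for every attention layer.
transformer.set_attn_processor(FluxAttnProcessorSDPA())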
@@ -286,105 +688,9 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -23,8 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..attention import Attention, AttentionMixin
from ..cache_utils import CacheMixin
from ..embeddings import (
    CombinedTimestepTextProjEmbeddings,
@@ -36,6 +35,7 @@ from ..embeddings import (
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
from .modeling_common import FeedForward


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -819,7 +819,7 @@ class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
        return hidden_states, encoder_hidden_states


class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin):
    r"""
    A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).

@@ -962,65 +962,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
    # Using methods from AttentionMixin

    def forward(
        self,
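Editor's note: the hunk above deletes the per-model `attn_processors` property and `set_attn_processor` method and relies on the new `AttentionMixin` base class instead. The mixin itself is not shown in this diff; the following is a minimal sketch of what it is assumed to provide, reusing the recursive walk from the removed code (class name and method signatures are assumptions, not the branch's actual implementation):

class AttentionMixin:
    @property
    def attn_processors(self):
        # Collect every processor, keyed by the attention module's weight name.
        processors = {}

        def fn_recursive_add_processors(name, module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor):
        # Accept a single processor for all layers, or a dict keyed like `attn_processors`.
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match"
                f" the number of attention layers: {count}."
            )

        def fn_recursive_attn_processor(name, module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)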
@@ -24,13 +24,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
from .modeling_common import FeedForward


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -13,28 +13,251 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..attention_processor import AttentionModuleMixin
from ..cache_utils import CacheMixin
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, RMSNorm
from .modeling_common import FeedForward


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MochiAttnProcessorSDPA:
    """Attention processor used in Mochi."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MochiAttnProcessorSDPA requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "MochiAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:

            def apply_rotary_emb(x, freqs_cos, freqs_sin):
                x_even = x[..., 0::2].float()
                x_odd = x[..., 1::2].float()

                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)

                return torch.stack([cos, sin], dim=-1).flatten(-2)

            query = apply_rotary_emb(query, *image_rotary_emb)
            key = apply_rotary_emb(key, *image_rotary_emb)

        query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        encoder_query, encoder_key, encoder_value = (
            encoder_query.transpose(1, 2),
            encoder_key.transpose(1, 2),
            encoder_value.transpose(1, 2),
        )

        sequence_length = query.size(2)
        encoder_sequence_length = encoder_query.size(2)
        total_length = sequence_length + encoder_sequence_length

        batch_size, heads, _, dim = query.shape
        attn_outputs = []
        for idx in range(batch_size):
            mask = attention_mask[idx][None, :]
            valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()

            valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
            valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]

            valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
            valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
            valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)

            attn_output = F.scaled_dot_product_attention(
                valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
            )
            valid_sequence_length = attn_output.size(2)
            attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
            attn_outputs.append(attn_output)

        hidden_states = torch.cat(attn_outputs, dim=0)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)

        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
            (sequence_length, encoder_sequence_length), dim=1
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states


class MochiRMSNorm(nn.Module):
    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        hidden_states = hidden_states.to(input_dtype)

        return hidden_states


@maybe_allow_in_graph
class MochiAttention(nn.Module, AttentionModuleMixin):
    default_processor_cls = MochiAttnProcessorSDPA
    _available_processors = [MochiAttnProcessorSDPA]

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_proj_bias: bool = True,
        out_dim: Optional[int] = None,
        out_context_dim: Optional[int] = None,
        context_pre_only: bool = False,
        eps: float = 1e-5,
    ):
        super().__init__()

        # Core parameters
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim
        self.heads = heads
        self.scale = dim_head**-0.5
        self.use_bias = bias
        self.scale_qk = True  # Always use scaled attention
        self.context_pre_only = context_pre_only
        self.eps = eps

        # Set output dimensions
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = out_context_dim if out_context_dim else query_dim

        # Self-attention projections
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # Normalization for queries and keys
        self.norm_q = MochiRMSNorm(dim_head, eps, True)
        self.norm_k = MochiRMSNorm(dim_head, eps, True)

        # Added key/value projections for joint processing
        self.added_kv_proj_dim = added_kv_proj_dim
        self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
        self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)

        # Normalization for added projections
        self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
        self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
        self.added_proj_bias = added_proj_bias

        # Output projections
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, self.out_dim, bias=bias), nn.Dropout(dropout)])

        # Context output projection
        if not context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=added_proj_bias)
        else:
            self.to_add_out = None

        # Initialize attention processor using the default class
        self.processor = self.default_processor_cls()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )


class MochiModulatedRMSNorm(nn.Module):
    def __init__(self, eps: float):
        super().__init__()
@@ -175,7 +398,6 @@ class MochiTransformerBlock(nn.Module):
            out_dim=dim,
            out_context_dim=pooled_projection_dim,
            context_pre_only=context_pre_only,
            processor=MochiAttnProcessor2_0(),
            eps=1e-5,
        )
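Editor's note: for reference, a minimal sketch of exercising the new `MochiAttention` with its default `MochiAttnProcessorSDPA`. The import path, tensor sizes, and the boolean prompt mask are illustrative assumptions, not values taken from this branch:

import torch

from diffusers.models.transformers.transformer_mochi import MochiAttention  # path assumed for this branch

attn = MochiAttention(query_dim=128, added_kv_proj_dim=96, heads=4, dim_head=32, out_context_dim=96)

hidden_states = torch.randn(2, 20, 128)              # latent/video tokens
encoder_hidden_states = torch.randn(2, 7, 96)        # caption tokens
attention_mask = torch.ones(2, 7, dtype=torch.bool)  # marks valid prompt tokens per sample

hidden_states, encoder_hidden_states = attn(
    hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
)
print(hidden_states.shape, encoder_hidden_states.shape)  # (2, 20, 128) and (2, 7, 96)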
@@ -20,15 +20,13 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
    Attention,
    AttentionProcessor,
    FusedJointAttnProcessor2_0,
    JointAttnProcessor2_0,
    AttentionModuleMixin,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput

@@ -36,6 +34,208 @@ from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class JointAttnProcessor:
    """Attention processor used for processing joint attention."""

    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("JointAttnProcessor requires PyTorch 2.0, please upgrade PyTorch.")

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        batch_size, sequence_length, _ = hidden_states.shape

        # Project query from hidden states
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            # Self-attention: Use hidden_states for key and value
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
        else:
            # Cross-attention: Use encoder_hidden_states for key and value
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        # Handle additional context for joint attention
        if hasattr(attn, "added_kv_proj_dim") and attn.added_kv_proj_dim is not None:
            context_key = attn.add_k_proj(encoder_hidden_states)
            context_value = attn.add_v_proj(encoder_hidden_states)
            context_query = attn.add_q_proj(encoder_hidden_states)

            # Joint query, key, value with context
            inner_dim = key.shape[-1]
            head_dim = inner_dim // attn.heads

            # Reshape for multi-head attention
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            context_query = context_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            context_key = context_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            context_value = context_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # Concatenate for joint attention
            query = torch.cat([context_query, query], dim=2)
            key = torch.cat([context_key, key], dim=2)
            value = torch.cat([context_value, value], dim=2)

            # Apply joint attention
            hidden_states = torch.nn.functional.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            # Reshape back to original dimensions
            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

            # Split context and hidden states
            context_len = encoder_hidden_states.shape[1]
            encoder_hidden_states, hidden_states = (
                hidden_states[:, :context_len],
                hidden_states[:, context_len:],
            )

            # Apply output projections
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

            if not attn.context_pre_only and hasattr(attn, "to_add_out") and attn.to_add_out is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
                return hidden_states, encoder_hidden_states

            return hidden_states

        # Handle standard attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # Reshape for multi-head attention
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Apply attention
        hidden_states = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Reshape output
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


@maybe_allow_in_graph
class SD3Attention(nn.Module, AttentionModuleMixin):
    """
    Specialized attention implementation for SD3 models.

    Features joint attention mechanisms and custom handling of
    context projections.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        out_dim: Optional[int] = None,
        context_pre_only: bool = False,
        eps: float = 1e-6,
    ):
        super().__init__()

        # Core parameters
        self.inner_dim = dim_head * heads
        self.query_dim = query_dim
        self.heads = heads
        self.scale = dim_head**-0.5
        self.scale_qk = True  # SD3 always scales query-key dot products
        self.use_bias = bias
        self.context_pre_only = context_pre_only
        self.eps = eps

        # Set output dimension
        out_dim = out_dim if out_dim is not None else query_dim

        # Set cross-attention parameters
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        # Linear projections for self-attention
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)

        # Optional added key/value projections for joint attention
        self.added_kv_proj_dim = added_kv_proj_dim
        if added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=bias)
            self.added_proj_bias = bias

        # Output projection for context
        if not context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, out_dim, bias=bias)
        else:
            self.to_add_out = None

        # Output projection and dropout
        self.to_out = nn.ModuleList([
            nn.Linear(self.inner_dim, out_dim, bias=bias),
            nn.Dropout(dropout),
        ])

        # Set processor
        self.processor = JointAttnProcessor()

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass for SD3 attention.

        Args:
            hidden_states: Input hidden states
            encoder_hidden_states: Optional encoder hidden states for cross/joint attention
            attention_mask: Optional attention mask
            position_ids: Optional position IDs

        Returns:
            Output hidden states, and optionally encoder hidden states for joint attention
        """
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )


@maybe_allow_in_graph
class SD3SingleTransformerBlock(nn.Module):
    def __init__(
@@ -47,13 +247,13 @@ class SD3SingleTransformerBlock(nn.Module):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.attn = Attention(
        # Use specialized SD3Attention instead of generic Attention
        self.attn = SD3Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )

@@ -78,7 +278,7 @@ class SD3SingleTransformerBlock(nn.Module):


class SD3Transformer2DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin, AttentionMixin
):
    """
    The Transformer model introduced in [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
@@ -214,105 +414,9 @@ class SD3Transformer2DModel(
        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}
    # Using inherited methods from AttentionMixin

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)
    # Using inherited methods from AttentionMixin

    def forward(
        self,
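Editor's note: a short, hedged sketch of how the new `SD3Attention` wired to `JointAttnProcessor` behaves when added KV projections are configured (the joint branch returns both token streams). The import path and dimensions are illustrative assumptions:

import torch

from diffusers.models.transformers.transformer_sd3 import JointAttnProcessor, SD3Attention  # path assumed

attn = SD3Attention(query_dim=64, heads=4, dim_head=16, added_kv_proj_dim=64, bias=True)

image_tokens = torch.randn(1, 16, 64)
text_tokens = torch.randn(1, 5, 64)

# added_kv_proj_dim is set, so JointAttnProcessor takes the joint path and returns both streams.
image_out, text_out = attn(image_tokens, encoder_hidden_states=text_tokens)
print(image_out.shape, text_out.shape)  # (1, 16, 64) and (1, 5, 64)

Because `SD3Transformer2DModel` now also inherits `AttentionMixin`, processor inspection and swapping go through the mixin's `attn_processors` / `set_attn_processor` rather than the per-model copies removed above.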
@@ -19,10 +19,11 @@ from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from ..attention import AttentionMixin
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import AlphaBlender
from .modeling_common import BasicTransformerBlock, TemporalBasicTransformerBlock


@dataclass
@@ -38,7 +39,7 @@ class TransformerTemporalModelOutput(BaseOutput):
    sample: torch.Tensor


class TransformerTemporalModel(ModelMixin, ConfigMixin):
class TransformerTemporalModel(ModelMixin, ConfigMixin, AttentionMixin):
    """
    A Transformer model for video-like data.

@@ -202,7 +203,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
        return TransformerTemporalModelOutput(sample=output)


class TransformerSpatioTemporalModel(nn.Module):
class TransformerSpatioTemporalModel(nn.Module, AttentionMixin):
    """
    A Transformer model for video-like data.
@@ -22,13 +22,13 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
from .modeling_common import FeedForward


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -21,7 +21,8 @@ from torch import nn
from ...utils import deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from ..attention import Attention
from ..attention_processor import AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from ..normalization import AdaGroupNorm
from ..resnet import (
    Downsample2D,
@@ -21,7 +21,8 @@ from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor
from ..attention import Attention
from ..attention_processor import AttentionProcessor, AttnProcessor
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@@ -24,7 +24,6 @@ from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import BasicTransformerBlock
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
@@ -41,6 +40,7 @@ from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.modeling_common import BasicTransformerBlock
from ..transformers.transformer_2d import Transformer2DModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
@@ -23,7 +23,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput
from ..attention_processor import Attention
from ..attention import Attention
from ..modeling_utils import ModelMixin
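Editor's note: the remaining hunks only move imports — `Attention` now comes from `..attention` while the processor classes stay in `..attention_processor`. A quick smoke-test sketch, assuming the public module paths mirror the internal ones on this branch:

from diffusers.models.attention import Attention  # new location assumed from the hunks above
from diffusers.models.attention_processor import AttnProcessor

attn = Attention(query_dim=32, heads=2, dim_head=16)
attn.set_processor(AttnProcessor())
print(type(attn.processor).__name__)  # AttnProcessor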