mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
Compare commits
7 Commits
chroma-fin
...
attn-refac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c87575dde6 | ||
|
|
3cb66e8786 | ||
|
|
2891f14127 | ||
|
|
d07da5da0a | ||
|
|
8f3c7692df | ||
|
|
ec77515f18 | ||
|
|
bd71805e6e |
@@ -11,36 +11,814 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
|
||||
# Import xformers only if it's available
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
except ImportError:
|
||||
xformers = None
|
||||
|
||||
from ..utils import logging
|
||||
from ..utils.import_utils import (
|
||||
is_torch_npu_available,
|
||||
is_torch_xla_available,
|
||||
is_xformers_available,
|
||||
)
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
from .attention_processor import (
|
||||
AttentionProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if hidden_states.shape[chunk_dim] % chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
class AttentionMixin:
|
||||
@property
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
|
||||
ff_output = torch.cat(
|
||||
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
|
||||
dim=chunk_dim,
|
||||
)
|
||||
return ff_output
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
"""
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
attn_processor.fused_projections = False
|
||||
|
||||
|
||||
class AttentionModuleMixin:
|
||||
_default_processor_cls = None
|
||||
_available_processors = []
|
||||
fused_projections = False
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor") -> None:
|
||||
"""
|
||||
Set the attention processor to use.
|
||||
|
||||
Args:
|
||||
processor (`AttnProcessor`):
|
||||
The attention processor to use.
|
||||
"""
|
||||
# if current processor is in `self._modules` and if passed `processor` is not, we need to
|
||||
# pop `processor` from `self._modules`
|
||||
if (
|
||||
hasattr(self, "processor")
|
||||
and isinstance(self.processor, torch.nn.Module)
|
||||
and not isinstance(processor, torch.nn.Module)
|
||||
):
|
||||
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
|
||||
self._modules.pop("processor")
|
||||
|
||||
self.processor = processor
|
||||
|
||||
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
|
||||
"""
|
||||
Get the attention processor in use.
|
||||
|
||||
Args:
|
||||
return_deprecated_lora (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to return the deprecated LoRA attention processor.
|
||||
|
||||
Returns:
|
||||
"AttentionProcessor": The attention processor in use.
|
||||
"""
|
||||
if not return_deprecated_lora:
|
||||
return self.processor
|
||||
|
||||
def set_attention_backend(self, backend: str):
|
||||
from .attention_dispatch import AttentionBackendName
|
||||
|
||||
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
||||
if backend not in available_backends:
|
||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||
|
||||
backend = AttentionBackendName(backend.lower())
|
||||
self.processor._attention_backend = backend
|
||||
|
||||
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
|
||||
"""
|
||||
Set whether to use NPU flash attention from `torch_npu` or not.
|
||||
|
||||
Args:
|
||||
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
|
||||
"""
|
||||
|
||||
if use_npu_flash_attention:
|
||||
if not is_torch_npu_available():
|
||||
raise ImportError("torch_npu is not available")
|
||||
|
||||
self.set_attention_backend("_native_npu")
|
||||
|
||||
def set_use_xla_flash_attention(
|
||||
self,
|
||||
use_xla_flash_attention: bool,
|
||||
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
|
||||
is_flux=False,
|
||||
) -> None:
|
||||
"""
|
||||
Set whether to use XLA flash attention from `torch_xla` or not.
|
||||
|
||||
Args:
|
||||
use_xla_flash_attention (`bool`):
|
||||
Whether to use pallas flash attention kernel from `torch_xla` or not.
|
||||
partition_spec (`Tuple[]`, *optional*):
|
||||
Specify the partition specification if using SPMD. Otherwise None.
|
||||
is_flux (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model is a Flux model.
|
||||
"""
|
||||
if use_xla_flash_attention:
|
||||
if not is_torch_xla_available():
|
||||
raise ImportError("torch_xla is not available")
|
||||
|
||||
self.set_attention_backend("_native_xla")
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
||||
) -> None:
|
||||
"""
|
||||
Set whether to use memory efficient attention from `xformers` or not.
|
||||
|
||||
Args:
|
||||
use_memory_efficient_attention_xformers (`bool`):
|
||||
Whether to use memory efficient attention from `xformers` or not.
|
||||
attention_op (`Callable`, *optional*):
|
||||
The attention operation to use. Defaults to `None` which uses the default attention operation from
|
||||
`xformers`.
|
||||
"""
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
raise ValueError(
|
||||
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
|
||||
" only available for GPU "
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# Make sure we can run the memory efficient attention
|
||||
if xformers is not None:
|
||||
dtype = None
|
||||
if attention_op is not None:
|
||||
op_fw, op_bw = attention_op
|
||||
dtype, *_ = op_fw.SUPPORTED_DTYPES
|
||||
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
|
||||
_ = xformers.ops.memory_efficient_attention(q, q, q)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
self.set_attention_backend("xformers")
|
||||
|
||||
@torch.no_grad()
|
||||
def fuse_projections(self):
|
||||
"""
|
||||
Fuse the query, key, and value projections into a single projection for efficiency.
|
||||
"""
|
||||
# Skip if already fused
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
device = self.to_q.weight.data.device
|
||||
dtype = self.to_q.weight.data.dtype
|
||||
|
||||
if hasattr(self, "is_cross_attention") and self.is_cross_attention:
|
||||
# Fuse cross-attention key-value projections
|
||||
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
||||
self.to_kv.weight.copy_(concatenated_weights)
|
||||
if hasattr(self, "use_bias") and self.use_bias:
|
||||
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
||||
self.to_kv.bias.copy_(concatenated_bias)
|
||||
else:
|
||||
# Fuse self-attention projections
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
|
||||
self.to_qkv.weight.copy_(concatenated_weights)
|
||||
if hasattr(self, "use_bias") and self.use_bias:
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
self.to_qkv.bias.copy_(concatenated_bias)
|
||||
|
||||
# Handle added projections for models like SD3, Flux, etc.
|
||||
if (
|
||||
getattr(self, "add_q_proj", None) is not None
|
||||
and getattr(self, "add_k_proj", None) is not None
|
||||
and getattr(self, "add_v_proj", None) is not None
|
||||
):
|
||||
concatenated_weights = torch.cat(
|
||||
[self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
|
||||
)
|
||||
in_features = concatenated_weights.shape[1]
|
||||
out_features = concatenated_weights.shape[0]
|
||||
|
||||
self.to_added_qkv = nn.Linear(
|
||||
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
|
||||
)
|
||||
self.to_added_qkv.weight.copy_(concatenated_weights)
|
||||
if self.added_proj_bias:
|
||||
concatenated_bias = torch.cat(
|
||||
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
|
||||
)
|
||||
self.to_added_qkv.bias.copy_(concatenated_bias)
|
||||
|
||||
self.fused_projections = True
|
||||
|
||||
@torch.no_grad()
|
||||
def unfuse_projections(self):
|
||||
"""
|
||||
Unfuse the query, key, and value projections back to separate projections.
|
||||
"""
|
||||
# Skip if not fused
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
# Remove fused projection layers
|
||||
if hasattr(self, "to_qkv"):
|
||||
delattr(self, "to_qkv")
|
||||
|
||||
if hasattr(self, "to_kv"):
|
||||
delattr(self, "to_kv")
|
||||
|
||||
if hasattr(self, "to_added_qkv"):
|
||||
delattr(self, "to_added_qkv")
|
||||
|
||||
self.fused_projections = False
|
||||
|
||||
def set_attention_slice(self, slice_size: int) -> None:
|
||||
"""
|
||||
Set the slice size for attention computation.
|
||||
|
||||
Args:
|
||||
slice_size (`int`):
|
||||
The slice size for attention computation.
|
||||
"""
|
||||
if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
|
||||
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
|
||||
|
||||
processor = None
|
||||
|
||||
# Try to get a compatible processor for sliced attention
|
||||
if slice_size is not None:
|
||||
processor = self._get_compatible_processor("sliced")
|
||||
|
||||
# If no processor was found or slice_size is None, use default processor
|
||||
if processor is None:
|
||||
processor = self.default_processor_cls()
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): The tensor to reshape.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The reshaped tensor.
|
||||
"""
|
||||
head_size = self.heads
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
|
||||
"""
|
||||
Reshape the tensor for multi-head attention processing.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): The tensor to reshape.
|
||||
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The reshaped tensor.
|
||||
"""
|
||||
head_size = self.heads
|
||||
if tensor.ndim == 3:
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
extra_dim = 1
|
||||
else:
|
||||
batch_size, extra_dim, seq_len, dim = tensor.shape
|
||||
tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
|
||||
tensor = tensor.permute(0, 2, 1, 3)
|
||||
|
||||
if out_dim == 3:
|
||||
tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
|
||||
|
||||
return tensor
|
||||
|
||||
def get_attention_scores(
|
||||
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the attention scores.
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`): The query tensor.
|
||||
key (`torch.Tensor`): The key tensor.
|
||||
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The attention probabilities/scores.
|
||||
"""
|
||||
dtype = query.dtype
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
if attention_mask is None:
|
||||
baddbmm_input = torch.empty(
|
||||
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
beta = 0
|
||||
else:
|
||||
baddbmm_input = attention_mask
|
||||
beta = 1
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
baddbmm_input,
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=beta,
|
||||
alpha=self.scale,
|
||||
)
|
||||
del baddbmm_input
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
del attention_scores
|
||||
|
||||
attention_probs = attention_probs.to(dtype)
|
||||
|
||||
return attention_probs
|
||||
|
||||
def prepare_attention_mask(
|
||||
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Prepare the attention mask for the attention computation.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`): The attention mask to prepare.
|
||||
target_length (`int`): The target length of the attention mask.
|
||||
batch_size (`int`): The batch size for repeating the attention mask.
|
||||
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The prepared attention mask.
|
||||
"""
|
||||
head_size = self.heads
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
current_length: int = attention_mask.shape[-1]
|
||||
if current_length != target_length:
|
||||
if attention_mask.device.type == "mps":
|
||||
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
||||
# Instead, we can manually construct the padding tensor.
|
||||
padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
|
||||
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
|
||||
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
||||
else:
|
||||
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
||||
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
||||
# remaining_length: int = target_length - current_length
|
||||
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
|
||||
if out_dim == 3:
|
||||
if attention_mask.shape[0] < batch_size * head_size:
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
||||
elif out_dim == 4:
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
||||
|
||||
return attention_mask
|
||||
|
||||
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Normalize the encoder hidden states.
|
||||
|
||||
Args:
|
||||
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The normalized encoder hidden states.
|
||||
"""
|
||||
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
||||
if isinstance(self.norm_cross, nn.LayerNorm):
|
||||
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
||||
elif isinstance(self.norm_cross, nn.GroupNorm):
|
||||
# Group norm norms along the channels dimension and expects
|
||||
# input to be in the shape of (N, C, *). In this case, we want
|
||||
# to norm along the hidden dimension, so we need to move
|
||||
# (batch_size, sequence_length, hidden_size) ->
|
||||
# (batch_size, hidden_size, sequence_length)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
||||
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
||||
else:
|
||||
assert False
|
||||
|
||||
return encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class Attention(nn.Module, AttentionModuleMixin):
|
||||
r"""
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`):
|
||||
The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8):
|
||||
The number of heads to use for multi-head attention.
|
||||
kv_heads (`int`, *optional*, defaults to `None`):
|
||||
The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
|
||||
`kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
|
||||
Query Attention (MQA) otherwise GQA is used.
|
||||
dim_head (`int`, *optional*, defaults to 64):
|
||||
The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
upcast_attention (`bool`, *optional*, defaults to False):
|
||||
Set to `True` to upcast the attention computation to `float32`.
|
||||
upcast_softmax (`bool`, *optional*, defaults to False):
|
||||
Set to `True` to upcast the softmax computation to `float32`.
|
||||
cross_attention_norm (`str`, *optional*, defaults to `None`):
|
||||
The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
|
||||
cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the group norm in the cross attention.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
norm_num_groups (`int`, *optional*, defaults to `None`):
|
||||
The number of groups to use for the group norm in the attention.
|
||||
spatial_norm_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the spatial normalization.
|
||||
out_bias (`bool`, *optional*, defaults to `True`):
|
||||
Set to `True` to use a bias in the output linear layer.
|
||||
scale_qk (`bool`, *optional*, defaults to `True`):
|
||||
Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
|
||||
only_cross_attention (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
|
||||
`added_kv_proj_dim` is not `None`.
|
||||
eps (`float`, *optional*, defaults to 1e-5):
|
||||
An additional value added to the denominator in group normalization that is used for numerical stability.
|
||||
rescale_output_factor (`float`, *optional*, defaults to 1.0):
|
||||
A factor to rescale the output by dividing it with this value.
|
||||
residual_connection (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` to add the residual connection to the output.
|
||||
_from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
|
||||
Set to `True` if the attention block is loaded from a deprecated state dict.
|
||||
processor (`AttnProcessor`, *optional*, defaults to `None`):
|
||||
The attention processor to use. If `None`, defaults to `AttnProcessorSDPA` if `torch 2.x` is used and
|
||||
`AttnProcessor` otherwise.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
cross_attention_norm: Optional[str] = None,
|
||||
cross_attention_norm_num_groups: int = 32,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
scale_qk: bool = True,
|
||||
only_cross_attention: bool = False,
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional["AttnProcessor"] = None,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
is_causal: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import.
|
||||
from .normalization import FP32LayerNorm, LpNorm
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
self.rescale_output_factor = rescale_output_factor
|
||||
self.residual_connection = residual_connection
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
self.is_causal = is_causal
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with an deprecated state dict so that we can convert it on the fly
|
||||
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
self.heads = out_dim // dim_head if out_dim is not None else heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
# You can set slice_size with `set_attention_slice`
|
||||
self.sliceable_head_dim = heads
|
||||
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.only_cross_attention = only_cross_attention
|
||||
|
||||
if self.added_kv_proj_dim is None and self.only_cross_attention:
|
||||
raise ValueError(
|
||||
"`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
|
||||
)
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
if spatial_norm_dim is not None:
|
||||
self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
|
||||
else:
|
||||
self.spatial_norm = None
|
||||
|
||||
if qk_norm is None:
|
||||
self.norm_q = None
|
||||
self.norm_k = None
|
||||
elif qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "layer_norm_across_heads":
|
||||
# Lumina applies qk norm across all heads
|
||||
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# LTX applies qk norm across all heads
|
||||
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
||||
elif qk_norm == "l2":
|
||||
self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
|
||||
self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
|
||||
)
|
||||
|
||||
if cross_attention_norm is None:
|
||||
self.norm_cross = None
|
||||
elif cross_attention_norm == "layer_norm":
|
||||
self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
|
||||
elif cross_attention_norm == "group_norm":
|
||||
if self.added_kv_proj_dim is not None:
|
||||
# The given `encoder_hidden_states` are initially of shape
|
||||
# (batch_size, seq_len, added_kv_proj_dim) before being projected
|
||||
# to (batch_size, seq_len, cross_attention_dim). The norm is applied
|
||||
# before the projection, so we need to use `added_kv_proj_dim` as
|
||||
# the number of channels for the group norm.
|
||||
norm_cross_num_channels = added_kv_proj_dim
|
||||
else:
|
||||
norm_cross_num_channels = self.cross_attention_dim
|
||||
|
||||
self.norm_cross = nn.GroupNorm(
|
||||
num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
||||
)
|
||||
|
||||
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
self.added_proj_bias = added_proj_bias
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
else:
|
||||
self.add_q_proj = None
|
||||
self.add_k_proj = None
|
||||
self.add_v_proj = None
|
||||
|
||||
if not self.pre_only:
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
else:
|
||||
self.to_out = None
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
||||
else:
|
||||
self.to_add_out = None
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "layer_norm":
|
||||
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "rms_norm":
|
||||
self.norm_added_q = RMSNorm(dim_head, eps=eps)
|
||||
self.norm_added_k = RMSNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "rms_norm_across_heads":
|
||||
# Wan applies qk norm across all heads
|
||||
# Wan also doesn't apply a q norm
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
|
||||
)
|
||||
else:
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
|
||||
# set attention processor
|
||||
# We use the AttnProcessorSDPA by default when torch 2.x is used which uses
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**cross_attention_kwargs,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
The forward method of the `Attention` class.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
The hidden states of the query.
|
||||
encoder_hidden_states (`torch.Tensor`, *optional*):
|
||||
The hidden states of the encoder.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
The attention mask to use. If `None`, no mask is applied.
|
||||
**cross_attention_kwargs:
|
||||
Additional keyword arguments to pass along to the cross attention.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The output of the attention layer.
|
||||
"""
|
||||
# The `Attention` class can call different attention processors / attention functions
|
||||
# here we simply pass along all tensors to the selected processor class
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
|
||||
unused_kwargs = [
|
||||
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
|
||||
]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
||||
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
|
||||
1098
src/diffusers/models/attention_dispatch.py
Normal file
1098
src/diffusers/models/attention_dispatch.py
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
1251
src/diffusers/models/transformers/modeling_common.py
Normal file
1251
src/diffusers/models/transformers/modeling_common.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -13,25 +13,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ..attention import Attention, AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -42,6 +36,270 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxAttnProcessor:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
encoder_projections = None
|
||||
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||
|
||||
return query, key, value, encoder_projections
|
||||
|
||||
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
encoder_projections = None
|
||||
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||
|
||||
return query, key, value, encoder_projections
|
||||
|
||||
def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||
"""Public method to get projections based on whether we're using fused mode or not."""
|
||||
if attn.is_fused and hasattr(attn, "to_qkv"):
|
||||
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
return self._get_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "FluxAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if encoder_projections is not None:
|
||||
encoder_query, encoder_key, encoder_value = encoder_projections
|
||||
encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
# Concatenate for joint attention
|
||||
query = torch.cat([encoder_query, query], dim=2)
|
||||
key = torch.cat([encoder_key, key], dim=2)
|
||||
value = torch.cat([encoder_value, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||
"""Flux Attention processor for IP-Adapter."""
|
||||
|
||||
def __init__(
|
||||
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[
|
||||
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||
for _ in range(len(num_tokens))
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "FluxAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
ip_hidden_states: Optional[List[torch.Tensor]] = None,
|
||||
ip_adapter_masks: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
hidden_states_query_proj = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
# IP-adapter
|
||||
ip_query = hidden_states_query_proj
|
||||
ip_attn_output = torch.zeros_like(hidden_states)
|
||||
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
|
||||
):
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
current_ip_hidden_states = F.scaled_dot_product_attention(
|
||||
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
|
||||
ip_attn_output += scale * current_ip_hidden_states
|
||||
|
||||
return hidden_states, encoder_hidden_states, ip_attn_output
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FluxAttention(Attention):
|
||||
_default_processor_cls = FluxAttnProcessor
|
||||
_available_processors = [
|
||||
FluxAttnProcessor,
|
||||
FluxIPAdapterAttnProcessor,
|
||||
]
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FluxSingleTransformerBlock(nn.Module):
|
||||
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
|
||||
@@ -53,24 +311,13 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
if is_torch_npu_available():
|
||||
deprecation_message = (
|
||||
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
|
||||
"should be set explicitly using the `set_attn_processor` method."
|
||||
)
|
||||
deprecate("npu_processor", "0.34.0", deprecation_message)
|
||||
processor = FluxAttnProcessor2_0_NPU()
|
||||
else:
|
||||
processor = FluxAttnProcessor2_0()
|
||||
|
||||
self.attn = Attention(
|
||||
self.attn = FluxAttention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
out_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
qk_norm="rms_norm",
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
@@ -113,18 +360,16 @@ class FluxTransformerBlock(nn.Module):
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
|
||||
self.attn = Attention(
|
||||
self.attn = FluxAttention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=FluxAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
dropout=0.0,
|
||||
bias=True,
|
||||
added_kv_proj_dim=dim,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
@@ -191,7 +436,13 @@ class FluxTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
|
||||
ModelMixin,
|
||||
ConfigMixin,
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
FluxTransformer2DLoadersMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
@@ -286,105 +537,9 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
# Using inherited methods from AttentionMixin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user