Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-15 00:44:51 +08:00)

Compare commits: add-uv-scr ... flux-attn- (8 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 941911c538 | |
| | c87575dde6 | |
| | 3cb66e8786 | |
| | 2891f14127 | |
| | d07da5da0a | |
| | 8f3c7692df | |
| | ec77515f18 | |
| | bd71805e6e | |
@@ -11,36 +11,816 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+import inspect
+from typing import Callable, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate, logging
+# Import xformers only if it's available
+try:
+    import xformers
+    import xformers.ops
+except ImportError:
+    xformers = None
+
+from ..utils import logging
+from ..utils.import_utils import (
+    is_torch_npu_available,
+    is_torch_xla_available,
+    is_xformers_available,
+)
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
-from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .attention_processor import (
+    AttentionProcessor,
+    AttnProcessor,
+)
+from .normalization import RMSNorm
 
 
 logger = logging.get_logger(__name__)
 
 
-def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
-    # "feed_forward_chunk_size" can be used to save memory
-    if hidden_states.shape[chunk_dim] % chunk_size != 0:
-        raise ValueError(
-            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
-        )
-
-    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    ff_output = torch.cat(
-        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-        dim=chunk_dim,
-    )
-    return ff_output
+class AttentionMixin:
+    @property
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+        """
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        for module in self.modules():
+            if isinstance(module, AttentionModuleMixin):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        for _, attn_processor in self.attn_processors.items():
+            attn_processor.fused_projections = False
+
+
+class AttentionModuleMixin:
+    _default_processor_cls = None
+    _available_processors = []
+    fused_projections = False
+
+    def set_processor(self, processor: "AttnProcessor") -> None:
+        """
+        Set the attention processor to use.
+
+        Args:
+            processor (`AttnProcessor`):
+                The attention processor to use.
+        """
+        # if current processor is in `self._modules` and if passed `processor` is not, we need to
+        # pop `processor` from `self._modules`
+        if (
+            hasattr(self, "processor")
+            and isinstance(self.processor, torch.nn.Module)
+            and not isinstance(processor, torch.nn.Module)
+        ):
+            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+            self._modules.pop("processor")
+
+        self.processor = processor
+
+    def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+        """
+        Get the attention processor in use.
+
+        Args:
+            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+                Set to `True` to return the deprecated LoRA attention processor.
+
+        Returns:
+            "AttentionProcessor": The attention processor in use.
+        """
+        if not return_deprecated_lora:
+            return self.processor
+
+    def set_attention_backend(self, backend: str):
+        from .attention_dispatch import AttentionBackendName
+
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+        backend = AttentionBackendName(backend.lower())
+        self.processor._attention_backend = backend
+
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        """
+        Set whether to use NPU flash attention from `torch_npu` or not.
+
+        Args:
+            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+        """
+        if use_npu_flash_attention:
+            if not is_torch_npu_available():
+                raise ImportError("torch_npu is not available")
+
+            self.set_attention_backend("_native_npu")
+
+    def set_use_xla_flash_attention(
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
+    ) -> None:
+        """
+        Set whether to use XLA flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                Specify the partition specification if using SPMD. Otherwise None.
+            is_flux (`bool`, *optional*, defaults to `False`):
+                Whether the model is a Flux model.
+        """
+        if use_xla_flash_attention:
+            if not is_torch_xla_available():
+                raise ImportError("torch_xla is not available")
+
+            self.set_attention_backend("_native_xla")
+
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ) -> None:
+        """
+        Set whether to use memory efficient attention from `xformers` or not.
+
+        Args:
+            use_memory_efficient_attention_xformers (`bool`):
+                Whether to use memory efficient attention from `xformers` or not.
+            attention_op (`Callable`, *optional*):
+                The attention operation to use. Defaults to `None` which uses the default attention operation from
+                `xformers`.
+        """
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+                    name="xformers",
+                )
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+                    " only available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    if xformers is not None:
+                        dtype = None
+                        if attention_op is not None:
+                            op_fw, op_bw = attention_op
+                            dtype, *_ = op_fw.SUPPORTED_DTYPES
+                        q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                        _ = xformers.ops.memory_efficient_attention(q, q, q)
+                except Exception as e:
+                    raise e
+
+            self.set_attention_backend("xformers")
+
+    @torch.no_grad()
+    def fuse_projections(self):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+        """
+        # Skip if already fused
+        if getattr(self, "fused_projections", False):
+            return
+
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+            # Fuse cross-attention key-value projections
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+                self.to_kv.bias.copy_(concatenated_bias)
+        else:
+            # Fuse self-attention projections
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+            if hasattr(self, "use_bias") and self.use_bias:
+                concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+                self.to_qkv.bias.copy_(concatenated_bias)
+
+        # Handle added projections for models like SD3, Flux, etc.
+        if (
+            getattr(self, "add_q_proj", None) is not None
+            and getattr(self, "add_k_proj", None) is not None
+            and getattr(self, "add_v_proj", None) is not None
+        ):
+            concatenated_weights = torch.cat(
+                [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+            )
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_added_qkv = nn.Linear(
+                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+            )
+            self.to_added_qkv.weight.copy_(concatenated_weights)
+            if self.added_proj_bias:
+                concatenated_bias = torch.cat(
+                    [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+                )
+                self.to_added_qkv.bias.copy_(concatenated_bias)
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        """
+        Unfuse the query, key, and value projections back to separate projections.
+        """
+        # Skip if not fused
+        if not getattr(self, "fused_projections", False):
+            return
+
+        # Remove fused projection layers
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+
+        if hasattr(self, "to_added_qkv"):
+            delattr(self, "to_added_qkv")
+
+        self.fused_projections = False
+
+    def set_attention_slice(self, slice_size: int) -> None:
+        """
+        Set the slice size for attention computation.
+
+        Args:
+            slice_size (`int`):
+                The slice size for attention computation.
+        """
+        if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        processor = None
+
+        # Try to get a compatible processor for sliced attention
+        if slice_size is not None:
+            processor = self._get_compatible_processor("sliced")
+
+        # If no processor was found or slice_size is None, use default processor
+        if processor is None:
+            processor = self.default_processor_cls()
+
+        self.set_processor(processor)
+
+    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        batch_size, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
+    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+        """
+        Reshape the tensor for multi-head attention processing.
+
+        Args:
+            tensor (`torch.Tensor`): The tensor to reshape.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+        Returns:
+            `torch.Tensor`: The reshaped tensor.
+        """
+        head_size = self.heads
+        if tensor.ndim == 3:
+            batch_size, seq_len, dim = tensor.shape
+            extra_dim = 1
+        else:
+            batch_size, extra_dim, seq_len, dim = tensor.shape
+        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3)
+
+        if out_dim == 3:
+            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+        return tensor
+
+    def get_attention_scores(
+        self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Compute the attention scores.
+
+        Args:
+            query (`torch.Tensor`): The query tensor.
+            key (`torch.Tensor`): The key tensor.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+        Returns:
+            `torch.Tensor`: The attention probabilities/scores.
+        """
+        dtype = query.dtype
+        if self.upcast_attention:
+            query = query.float()
+            key = key.float()
+
+        if attention_mask is None:
+            baddbmm_input = torch.empty(
+                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+            )
+            beta = 0
+        else:
+            baddbmm_input = attention_mask
+            beta = 1
+
+        attention_scores = torch.baddbmm(
+            baddbmm_input,
+            query,
+            key.transpose(-1, -2),
+            beta=beta,
+            alpha=self.scale,
+        )
+        del baddbmm_input
+
+        if self.upcast_softmax:
+            attention_scores = attention_scores.float()
+
+        attention_probs = attention_scores.softmax(dim=-1)
+        del attention_scores
+
+        attention_probs = attention_probs.to(dtype)
+
+        return attention_probs
+
+    def prepare_attention_mask(
+        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+    ) -> torch.Tensor:
+        """
+        Prepare the attention mask for the attention computation.
+
+        Args:
+            attention_mask (`torch.Tensor`): The attention mask to prepare.
+            target_length (`int`): The target length of the attention mask.
+            batch_size (`int`): The batch size for repeating the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+        Returns:
+            `torch.Tensor`: The prepared attention mask.
+        """
+        head_size = self.heads
+        if attention_mask is None:
+            return attention_mask
+
+        current_length: int = attention_mask.shape[-1]
+        if current_length != target_length:
+            if attention_mask.device.type == "mps":
+                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+                # Instead, we can manually construct the padding tensor.
+                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+                attention_mask = torch.cat([attention_mask, padding], dim=2)
+            else:
+                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+                # we want to instead pad by (0, remaining_length), where remaining_length is:
+                # remaining_length: int = target_length - current_length
+                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+        if out_dim == 3:
+            if attention_mask.shape[0] < batch_size * head_size:
+                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+        elif out_dim == 4:
+            attention_mask = attention_mask.unsqueeze(1)
+            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+        return attention_mask
+
+    def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Normalize the encoder hidden states.
+
+        Args:
+            encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+        Returns:
+            `torch.Tensor`: The normalized encoder hidden states.
+        """
+        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+        if isinstance(self.norm_cross, nn.LayerNorm):
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+        elif isinstance(self.norm_cross, nn.GroupNorm):
+            # Group norm norms along the channels dimension and expects
+            # input to be in the shape of (N, C, *). In this case, we want
+            # to norm along the hidden dimension, so we need to move
+            # (batch_size, sequence_length, hidden_size) ->
+            # (batch_size, hidden_size, sequence_length)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+        else:
+            assert False
+
+        return encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class Attention(nn.Module, AttentionModuleMixin):
+    r"""
+    A cross attention layer.
+
+    Parameters:
+        query_dim (`int`):
+            The number of channels in the query.
+        cross_attention_dim (`int`, *optional*):
+            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+        heads (`int`, *optional*, defaults to 8):
+            The number of heads to use for multi-head attention.
+        kv_heads (`int`, *optional*, defaults to `None`):
+            The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
+            `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
+            Query Attention (MQA) otherwise GQA is used.
+        dim_head (`int`, *optional*, defaults to 64):
+            The number of channels in each head.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+        bias (`bool`, *optional*, defaults to False):
+            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+        upcast_attention (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the attention computation to `float32`.
+        upcast_softmax (`bool`, *optional*, defaults to False):
+            Set to `True` to upcast the softmax computation to `float32`.
+        cross_attention_norm (`str`, *optional*, defaults to `None`):
+            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
+            The number of groups to use for the group norm in the cross attention.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        norm_num_groups (`int`, *optional*, defaults to `None`):
+            The number of groups to use for the group norm in the attention.
+        spatial_norm_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the spatial normalization.
+        out_bias (`bool`, *optional*, defaults to `True`):
+            Set to `True` to use a bias in the output linear layer.
+        scale_qk (`bool`, *optional*, defaults to `True`):
+            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
+        only_cross_attention (`bool`, *optional*, defaults to `False`):
+            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
+            `added_kv_proj_dim` is not `None`.
+        eps (`float`, *optional*, defaults to 1e-5):
+            An additional value added to the denominator in group normalization that is used for numerical stability.
+        rescale_output_factor (`float`, *optional*, defaults to 1.0):
+            A factor to rescale the output by dividing it with this value.
+        residual_connection (`bool`, *optional*, defaults to `False`):
+            Set to `True` to add the residual connection to the output.
+        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
+            Set to `True` if the attention block is loaded from a deprecated state dict.
+        processor (`AttnProcessor`, *optional*, defaults to `None`):
+            The attention processor to use. If `None`, defaults to `AttnProcessorSDPA` if `torch 2.x` is used and
+            `AttnProcessor` otherwise.
+    """
+
+    def __init__(
+        self,
+        query_dim: int,
+        cross_attention_dim: Optional[int] = None,
+        heads: int = 8,
+        kv_heads: Optional[int] = None,
+        dim_head: int = 64,
+        dropout: float = 0.0,
+        bias: bool = False,
+        upcast_attention: bool = False,
+        upcast_softmax: bool = False,
+        cross_attention_norm: Optional[str] = None,
+        cross_attention_norm_num_groups: int = 32,
+        qk_norm: Optional[str] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        added_proj_bias: Optional[bool] = True,
+        norm_num_groups: Optional[int] = None,
+        spatial_norm_dim: Optional[int] = None,
+        out_bias: bool = True,
+        scale_qk: bool = True,
+        only_cross_attention: bool = False,
+        eps: float = 1e-5,
+        rescale_output_factor: float = 1.0,
+        residual_connection: bool = False,
+        _from_deprecated_attn_block: bool = False,
+        processor: Optional["AttnProcessor"] = None,
+        out_dim: int = None,
+        out_context_dim: int = None,
+        context_pre_only=None,
+        pre_only=False,
+        elementwise_affine: bool = True,
+        is_causal: bool = False,
+    ):
+        super().__init__()
+
+        # To prevent circular import.
+        from .normalization import FP32LayerNorm, LpNorm
+
+        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+        self.query_dim = query_dim
+        self.use_bias = bias
+        self.is_cross_attention = cross_attention_dim is not None
+        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+        self.upcast_attention = upcast_attention
+        self.upcast_softmax = upcast_softmax
+        self.rescale_output_factor = rescale_output_factor
+        self.residual_connection = residual_connection
+        self.dropout = dropout
+        self.fused_projections = False
+        self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+        self.context_pre_only = context_pre_only
+        self.pre_only = pre_only
+        self.is_causal = is_causal
+
+        # we make use of this private variable to know whether this class is loaded
+        # with an deprecated state dict so that we can convert it on the fly
+        self._from_deprecated_attn_block = _from_deprecated_attn_block
+
+        self.scale_qk = scale_qk
+        self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+        self.heads = out_dim // dim_head if out_dim is not None else heads
+        # for slice_size > 0 the attention score computation
+        # is split across the batch axis to save memory
+        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
+
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.only_cross_attention = only_cross_attention
+
+        if self.added_kv_proj_dim is None and self.only_cross_attention:
+            raise ValueError(
+                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+            )
+
+        if norm_num_groups is not None:
+            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
+        else:
+            self.group_norm = None
+
+        if spatial_norm_dim is not None:
+            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
+        else:
+            self.spatial_norm = None
+
+        if qk_norm is None:
+            self.norm_q = None
+            self.norm_k = None
+        elif qk_norm == "layer_norm":
+            self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        elif qk_norm == "fp32_layer_norm":
+            self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+        elif qk_norm == "layer_norm_across_heads":
+            # Lumina applies qk norm across all heads
+            self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
+            self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "rms_norm":
+            self.norm_q = RMSNorm(dim_head, eps=eps)
+            self.norm_k = RMSNorm(dim_head, eps=eps)
+        elif qk_norm == "rms_norm_across_heads":
+            # LTX applies qk norm across all heads
+            self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+            self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps)
+        elif qk_norm == "l2":
+            self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
+            self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
+        else:
+            raise ValueError(
+                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+            )
+
+        if cross_attention_norm is None:
+            self.norm_cross = None
+        elif cross_attention_norm == "layer_norm":
+            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
+        elif cross_attention_norm == "group_norm":
+            if self.added_kv_proj_dim is not None:
+                # The given `encoder_hidden_states` are initially of shape
+                # (batch_size, seq_len, added_kv_proj_dim) before being projected
+                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
+                # before the projection, so we need to use `added_kv_proj_dim` as
+                # the number of channels for the group norm.
+                norm_cross_num_channels = added_kv_proj_dim
+            else:
+                norm_cross_num_channels = self.cross_attention_dim
+
+            self.norm_cross = nn.GroupNorm(
+                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
+            )
+        else:
+            raise ValueError(
+                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
+            )
+
+        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+        if not self.only_cross_attention:
+            # only relevant for the `AddedKVProcessor` classes
+            self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+            self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+        else:
+            self.to_k = None
+            self.to_v = None
+
+        self.added_proj_bias = added_proj_bias
+        if self.added_kv_proj_dim is not None:
+            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+        else:
+            self.add_q_proj = None
+            self.add_k_proj = None
+            self.add_v_proj = None
+
+        if not self.pre_only:
+            self.to_out = nn.ModuleList([])
+            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+            self.to_out.append(nn.Dropout(dropout))
+        else:
+            self.to_out = None
+
+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
+        else:
+            self.to_add_out = None
+
+        if qk_norm is not None and added_kv_proj_dim is not None:
+            if qk_norm == "layer_norm":
+                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            elif qk_norm == "fp32_layer_norm":
+                self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+                self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
+            elif qk_norm == "rms_norm":
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            elif qk_norm == "rms_norm_across_heads":
+                # Wan applies qk norm across all heads
+                # Wan also doesn't apply a q norm
+                self.norm_added_q = None
+                self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
+            else:
+                raise ValueError(
+                    f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+                )
+        else:
+            self.norm_added_q = None
+            self.norm_added_k = None
+
+        # set attention processor
+        # We use the AttnProcessorSDPA by default when torch 2.x is used which uses
+        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
+        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        if processor is None:
+            processor = self._default_processor_cls
+        self.set_processor(processor)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **cross_attention_kwargs,
+    ) -> torch.Tensor:
+        r"""
+        The forward method of the `Attention` class.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                The hidden states of the query.
+            encoder_hidden_states (`torch.Tensor`, *optional*):
+                The hidden states of the encoder.
+            attention_mask (`torch.Tensor`, *optional*):
+                The attention mask to use. If `None`, no mask is applied.
+            **cross_attention_kwargs:
+                Additional keyword arguments to pass along to the cross attention.
+
+        Returns:
+            `torch.Tensor`: The output of the attention layer.
+        """
+        # The `Attention` class can call different attention processors / attention functions
+        # here we simply pass along all tensors to the selected processor class
+        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+        return self.processor(
+            self,
+            hidden_states,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=attention_mask,
+            **cross_attention_kwargs,
+        )
 
 
 @maybe_allow_in_graph
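The recursion behind `AttentionMixin.attn_processors` and `set_attn_processor` above is easier to see in isolation. The following is a minimal standalone sketch of that registry pattern using plain torch modules; `ToyProcessor`, `ToyAttention` and `ToyModel` are hypothetical stand-ins, not diffusers classes.

```python
# Minimal sketch of the recursive processor registry used by AttentionMixin above.
from typing import Dict

import torch
import torch.nn as nn


class ToyProcessor:
    def __call__(self, attn, hidden_states):
        return hidden_states


class ToyAttention(nn.Module):
    # Any module exposing get_processor/set_processor participates in the registry.
    def __init__(self):
        super().__init__()
        self.processor = ToyProcessor()

    def get_processor(self):
        return self.processor

    def set_processor(self, processor):
        self.processor = processor


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = nn.ModuleDict({"attn": ToyAttention()})
        self.block2 = nn.ModuleDict({"attn": ToyAttention()})

    @property
    def attn_processors(self) -> Dict[str, object]:
        # Same walk as AttentionMixin.attn_processors: recurse over named children and
        # record every module that exposes `get_processor` under "<path>.processor".
        processors = {}

        def fn(name: str, module: torch.nn.Module, processors):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn(name, module, processors)
        return processors


model = ToyModel()
print(list(model.attn_processors.keys()))
# ['block1.attn.processor', 'block2.attn.processor']
```

The dotted keys are what `set_attn_processor` expects when it is given a dict, which is why a processor dict must contain exactly one entry per attention layer.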
src/diffusers/models/attention_dispatch.py (new file, 1098 lines)
File diff suppressed because it is too large.
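Since `attention_dispatch.py` is collapsed in this view, here is a rough standalone sketch of the idea its call sites suggest (`AttentionBackendName`, `dispatch_attention_fn`, and the per-processor `_attention_backend` field seen elsewhere in this diff): a name-keyed registry of interchangeable attention kernels. The enum members and signature below are assumptions for illustration only, not the real module; the only backend names confirmed by this diff are "_native_npu", "_native_xla" and "xformers".

```python
# Hypothetical sketch of name-based attention dispatch; not the real attention_dispatch.py.
from enum import Enum
from typing import Optional

import torch
import torch.nn.functional as F


class AttentionBackendName(str, Enum):
    NATIVE = "native"  # torch scaled_dot_product_attention
    MATH = "math"      # explicit matmul + softmax reference


def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    backend: AttentionBackendName = AttentionBackendName.NATIVE,
) -> torch.Tensor:
    # Route the same (batch, heads, seq, head_dim) tensors to interchangeable kernels.
    if backend == AttentionBackendName.MATH:
        scale = query.shape[-1] ** -0.5
        scores = query @ key.transpose(-1, -2) * scale
        if attn_mask is not None:
            scores = scores + attn_mask
        return scores.softmax(dim=-1) @ value
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)


q = k = v = torch.randn(1, 8, 16, 64)
assert torch.allclose(
    dispatch_attention_fn(q, k, v),
    dispatch_attention_fn(q, k, v, backend=AttentionBackendName.MATH),
    atol=1e-4,
)
```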
@@ -1244,31 +1244,21 @@ class FluxPosEmbed(nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
         super().__init__()
+
+        from .transformers.transformer_flux import FluxPosEmbed as FluxPosEmbed_
+
+        deprecate(
+            "FluxPosEmbed",
+            "1.0.0",
+            "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please use `FluxPosEmbed` from `diffusers.models.transformers.transformer_flux` instead.",
+        )
+
         self.theta = theta
         self.axes_dim = axes_dim
+        self._rope = FluxPosEmbed_(theta, axes_dim)
 
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
+        return self._rope(ids)
 
 
 class TimestepEmbedding(nn.Module):
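After this hunk, `diffusers.models.embeddings.FluxPosEmbed` is only a thin shim: its constructor emits the `deprecate("FluxPosEmbed", "1.0.0", ...)` warning and its forward delegates to the implementation that now lives in `transformer_flux`. A minimal usage sketch of the two call paths, assuming a diffusers build containing this branch; the `theta`, `axes_dim` and `ids` values are illustrative, not values mandated by the diff.

```python
# Sketch only: both objects compute the same rotary embeddings, but the first import
# path goes through the deprecation shim shown in the hunk above.
import torch

from diffusers.models.embeddings import FluxPosEmbed as DeprecatedFluxPosEmbed
from diffusers.models.transformers.transformer_flux import FluxPosEmbed

ids = torch.zeros(64, 3)  # (sequence_length, n_axes) position ids, illustrative shape
old_cos, old_sin = DeprecatedFluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])(ids)
new_cos, new_sin = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])(ids)
assert torch.allclose(old_cos, new_cos) and torch.allclose(old_sin, new_sin)
```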
@@ -599,6 +599,50 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+        backend = AttentionBackendName(backend)
+
+        for module in self.modules():
+            if not isinstance(module, AttentionModuleMixin):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention_processor import Attention, MochiAttention
+
+        attention_classes = (Attention, MochiAttention)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
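A usage sketch for the two `ModelMixin` helpers above. It assumes a diffusers build that contains this branch; the checkpoint id is illustrative, and "xformers" is used only because it is one of the backend names visible elsewhere in this diff (the xformers package itself must be installed for that choice to work).

```python
# Switch every AttentionModuleMixin processor in a model to a named backend, then reset.
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

transformer.set_attention_backend("xformers")  # sets processor._attention_backend on each attention module
# ... run the transformer / pipeline forward here ...
transformer.reset_attention_backend()          # back to the default torch SDPA path
```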
src/diffusers/models/transformers/modeling_common.py (new file, 1251 lines)
File diff suppressed because it is too large.
@@ -13,27 +13,26 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
-from ...utils.import_utils import is_torch_npu_available
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    apply_rotary_emb,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +41,323 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
|
|||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class FluxAttnProcessor:
|
||||||
|
def __init__(self):
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
|
||||||
|
|
||||||
|
def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||||
|
query = attn.to_q(hidden_states)
|
||||||
|
key = attn.to_k(hidden_states)
|
||||||
|
value = attn.to_v(hidden_states)
|
||||||
|
|
||||||
|
encoder_projections = None
|
||||||
|
if encoder_hidden_states is not None and hasattr(attn, "add_q_proj"):
|
||||||
|
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||||
|
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||||
|
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||||
|
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||||
|
|
||||||
|
return query, key, value, encoder_projections
|
||||||
|
|
||||||
|
def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||||
|
qkv = attn.to_qkv(hidden_states)
|
||||||
|
split_size = qkv.shape[-1] // 3
|
||||||
|
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||||
|
|
||||||
|
encoder_projections = None
|
||||||
|
if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
|
||||||
|
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||||
|
split_size = encoder_qkv.shape[-1] // 3
|
||||||
|
encoder_query, encoder_key, encoder_value = torch.split(encoder_qkv, split_size, dim=-1)
|
||||||
|
encoder_projections = (encoder_query, encoder_key, encoder_value)
|
||||||
|
|
||||||
|
return query, key, value, encoder_projections
|
||||||
|
|
||||||
|
def get_qkv_projections(self, attn, hidden_states, encoder_hidden_states=None):
|
||||||
|
"""Public method to get projections based on whether we're using fused mode or not."""
|
||||||
|
if attn.is_fused and hasattr(attn, "to_qkv"):
|
||||||
|
return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
|
||||||
|
return self._get_projections(attn, hidden_states, encoder_hidden_states)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
attn: "FluxAttention",
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
encoder_hidden_states: torch.Tensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||||
|
|
||||||
|
query, key, value, encoder_projections = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||||
|
|
||||||
|
inner_dim = key.shape[-1]
|
||||||
|
head_dim = inner_dim // attn.heads
|
||||||
|
|
||||||
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if attn.norm_q is not None:
|
||||||
|
query = attn.norm_q(query)
|
||||||
|
if attn.norm_k is not None:
|
||||||
|
key = attn.norm_k(key)
|
||||||
|
|
||||||
|
if encoder_projections is not None:
|
||||||
|
encoder_query, encoder_key, encoder_value = encoder_projections
|
||||||
|
encoder_query = encoder_query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
encoder_key = encoder_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
encoder_value = encoder_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if attn.norm_added_q is not None:
|
||||||
|
encoder_query = attn.norm_added_q(encoder_query)
|
||||||
|
if attn.norm_added_k is not None:
|
||||||
|
encoder_key = attn.norm_added_k(encoder_key)
|
||||||
|
|
||||||
|
# Concatenate for joint attention
|
||||||
|
query = torch.cat([encoder_query, query], dim=2)
|
||||||
|
key = torch.cat([encoder_key, key], dim=2)
|
||||||
|
value = torch.cat([encoder_value, value], dim=2)
|
||||||
|
|
||||||
|
if image_rotary_emb is not None:
|
||||||
|
query = apply_rotary_emb(query, image_rotary_emb)
|
||||||
|
key = apply_rotary_emb(key, image_rotary_emb)
|
||||||
|
|
||||||
|
hidden_states = dispatch_attention_fn(query, key, value, attn_mask=attention_mask)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||||
|
hidden_states = hidden_states.to(query.dtype)
|
||||||
|
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
encoder_hidden_states, hidden_states = (
|
||||||
|
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||||
|
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = attn.to_out[0](hidden_states)
|
||||||
|
hidden_states = attn.to_out[1](hidden_states)
|
||||||
|
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, encoder_hidden_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class FluxIPAdapterAttnProcessor(torch.nn.Module):
|
||||||
|
"""Flux Attention processor for IP-Adapter."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if not hasattr(F, "scaled_dot_product_attention"):
|
||||||
|
raise ImportError(
|
||||||
|
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.cross_attention_dim = cross_attention_dim
|
||||||
|
|
||||||
|
if not isinstance(num_tokens, (tuple, list)):
|
||||||
|
num_tokens = [num_tokens]
|
||||||
|
|
||||||
|
if not isinstance(scale, list):
|
||||||
|
scale = [scale] * len(num_tokens)
|
||||||
|
if len(scale) != len(num_tokens):
|
||||||
|
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
self.to_k_ip = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
||||||
|
for _ in range(len(num_tokens))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.to_v_ip = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
|
                for _ in range(len(num_tokens))
            ]
        )

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        ip_hidden_states: Optional[List[torch.Tensor]] = None,
        ip_adapter_masks: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        hidden_states_query_proj = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        hidden_states = dispatch_attention_fn(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # IP-adapter
            ip_query = hidden_states_query_proj
            ip_attn_output = torch.zeros_like(hidden_states)

            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
            ):
                ip_key = to_k_ip(current_ip_hidden_states)
                ip_value = to_v_ip(current_ip_hidden_states)

                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                # the output of sdp = (batch, num_heads, seq_len, head_dim)
                # TODO: add support for attn.scale when we move to Torch 2.1
                current_ip_hidden_states = F.scaled_dot_product_attention(
                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                )
                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
                    batch_size, -1, attn.heads * head_dim
                )
                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
                ip_attn_output += scale * current_ip_hidden_states

            return hidden_states, encoder_hidden_states, ip_attn_output
        else:
            return hidden_states
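The processor above delegates rotary position handling to `apply_rotary_emb` and backend selection to `dispatch_attention_fn`. For intuition only, here is a minimal sketch of rotary application over interleaved channel pairs; it assumes the `repeat_interleave_real=True` cos/sin layout produced by `FluxPosEmbed` further down and is not the exact helper from `diffusers.models.embeddings`.

```python
import torch


def rotary_apply_sketch(x: torch.Tensor, freqs: tuple) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); freqs = (cos, sin), each (seq_len, head_dim)
    # with every rotation angle repeated twice along the channel dimension.
    cos, sin = freqs
    x1, x2 = x[..., 0::2], x[..., 1::2]                   # even / odd channel pairs
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)  # (-x2, x1) re-interleaved
    return x * cos + rotated * sin
```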
@maybe_allow_in_graph
class FluxAttention(torch.nn.Module, AttentionModuleMixin):
    _default_processor_cls = FluxAttnProcessor
    _available_processors = [
        FluxAttnProcessor,
        FluxIPAdapterAttnProcessor,
    ]

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        added_proj_bias: Optional[bool] = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int = None,
        context_pre_only: Optional[bool] = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        processor=None,
    ):
        super().__init__()
        assert qk_norm == "rms_norm", "Flux uses RMSNorm"

        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_proj_bias = added_proj_bias

        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

        if not self.pre_only:
            self.to_out = torch.nn.ModuleList([])
            self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
            self.to_out.append(torch.nn.Dropout(dropout))  # processors above index to_out[1] as the dropout layer

        if added_kv_proj_dim is not None:
            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
            self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
            self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)

        if processor is None:
            processor = self._default_processor_cls()
        self.set_processor(processor)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
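For orientation, here is a minimal standalone sketch of exercising the `FluxAttention` module defined above. The import path, the toy shapes, and the tuple return convention (taken from the processors shown earlier) are assumptions, not part of the diff.

```python
import torch

# Module path assumed for this branch; FluxAttention is the class defined above.
from diffusers.models.transformers.transformer_flux import FluxAttention

attn = FluxAttention(
    query_dim=128,          # toy joint hidden size
    heads=4,
    dim_head=32,
    qk_norm="rms_norm",     # required: the constructor asserts rms_norm
    added_kv_proj_dim=128,  # enables the context (text-stream) projections
    bias=True,
    eps=1e-6,
)

image_tokens = torch.randn(1, 16, 128)  # latent/image stream
text_tokens = torch.randn(1, 8, 128)    # text/context stream

# With encoder_hidden_states set, the default FluxAttnProcessor returns the
# image-stream and text-stream outputs separately.
img_out, txt_out = attn(hidden_states=image_tokens, encoder_hidden_states=text_tokens)
print(img_out.shape, txt_out.shape)  # (1, 16, 128), (1, 8, 128)
```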
@@ -53,24 +369,13 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        if is_torch_npu_available():
-            deprecation_message = (
-                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
-                "should be set explicitly using the `set_attn_processor` method."
-            )
-            deprecate("npu_processor", "0.34.0", deprecation_message)
-            processor = FluxAttnProcessor2_0_NPU()
-        else:
-            processor = FluxAttnProcessor2_0()
-
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
+            out_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
-            out_dim=dim,
+            dropout=0.0,
             bias=True,
-            processor=processor,
             qk_norm="rms_norm",
             eps=1e-6,
             pre_only=True,
@@ -83,20 +388,19 @@ class FluxSingleTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
-        residual = hidden_states
+        joint_attention_kwargs = joint_attention_kwargs or {}
 
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-        joint_attention_kwargs = joint_attention_kwargs or {}
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
             **joint_attention_kwargs,
         )
+        attn_mlp_hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+        proj_out = self.proj_out(attn_mlp_hidden_states)
+        hidden_states = hidden_states + gate.unsqueeze(1) * proj_out
 
-        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
-        gate = gate.unsqueeze(1)
-        hidden_states = gate * self.proj_out(hidden_states)
-        hidden_states = residual + hidden_states
         if hidden_states.dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)
@@ -113,18 +417,16 @@ class FluxTransformerBlock(nn.Module):
         self.norm1 = AdaLayerNormZero(dim)
         self.norm1_context = AdaLayerNormZero(dim)
 
-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
             cross_attention_dim=None,
-            added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
-            out_dim=dim,
-            context_pre_only=False,
-            bias=True,
-            processor=FluxAttnProcessor2_0(),
             qk_norm=qk_norm,
             eps=eps,
+            dropout=0.0,
+            bias=True,
+            added_kv_proj_dim=dim,
         )
 
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -141,12 +443,13 @@ class FluxTransformerBlock(nn.Module):
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+        joint_attention_kwargs = joint_attention_kwargs or {}
+
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-        joint_attention_kwargs = joint_attention_kwargs or {}
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
@@ -165,7 +468,7 @@ class FluxTransformerBlock(nn.Module):
         hidden_states = hidden_states + attn_output
 
         norm_hidden_states = self.norm2(hidden_states)
-        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
 
         ff_output = self.ff(norm_hidden_states)
         ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -175,23 +478,62 @@ class FluxTransformerBlock(nn.Module):
         hidden_states = hidden_states + ip_attn_output
 
         # Process attention outputs for the `encoder_hidden_states`.
 
         context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
         encoder_hidden_states = encoder_hidden_states + context_attn_output
 
         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
-        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (
+            1 + c_scale_mlp.unsqueeze(1)
+        ) + c_shift_mlp.unsqueeze(1)
 
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
 
         if encoder_hidden_states.dtype == torch.float16:
             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
 
         return encoder_hidden_states, hidden_states
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux.
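A short usage sketch for the newly added `FluxPosEmbed` (illustrative only; the three-axis id layout and the `axes_dim` values follow the usual Flux configuration but are assumptions here):

```python
import torch

# FluxPosEmbed as defined in the hunk above; theta/axes_dim follow the common
# Flux setup (text index, image row, image column) and sum to the head dim.
pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])

height = width = 8
img_ids = torch.zeros(height * width, 3)
img_ids[:, 1] = torch.arange(height).repeat_interleave(width)  # row index
img_ids[:, 2] = torch.arange(width).repeat(height)             # column index
txt_ids = torch.zeros(4, 3)                                     # text tokens sit at position 0

ids = torch.cat([txt_ids, img_ids], dim=0)
freqs_cos, freqs_sin = pos_embed(ids)  # each: (seq_len, sum(axes_dim)) == (68, 128)
```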
@@ -286,106 +628,6 @@ class FluxTransformer2DModel(
 
         self.gradient_checkpointing = False
 
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
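Because `FluxTransformer2DModel` now inherits `AttentionMixin`, the processor plumbing removed in this hunk stays available through the mixin. A hedged sketch of the unchanged public surface (the repo id, dtype, and processor re-instantiation are illustrative):

```python
import torch
from diffusers import FluxTransformer2DModel

# Assumes the Flux weights are available locally or via the Hub.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

procs = model.attn_processors                       # dict keyed by module path, via AttentionMixin
proc_cls = type(next(iter(procs.values())))         # e.g. FluxAttnProcessor on this branch
model.set_attn_processor({name: proc_cls() for name in procs})  # a dict must cover every key
```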
@@ -479,11 +721,7 @@ class FluxTransformer2DModel(
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
                 )
 
             else:
@@ -510,12 +748,7 @@ class FluxTransformer2DModel(
 
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                )
-
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb, image_rotary_emb)
             else:
                 hidden_states = block(
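Continuing the sketch above: the collapsed `_gradient_checkpointing_func` calls are only exercised when checkpointing is switched on, which is a standard `ModelMixin` toggle.

```python
# Illustrative; `model` is the FluxTransformer2DModel from the previous sketch.
model.enable_gradient_checkpointing()
model.train()
# During the backward pass each transformer block is now re-computed instead of cached.
```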
@@ -67,6 +67,9 @@ from .import_utils import (
     is_bitsandbytes_version,
     is_bs4_available,
     is_cosmos_guardrail_available,
+    is_flash_attn_3_available,
+    is_flash_attn_available,
+    is_flash_attn_version,
     is_flax_available,
     is_ftfy_available,
     is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
     is_peft_version,
     is_pytorch_retinaface_available,
     is_safetensors_available,
+    is_sageattention_available,
+    is_sageattention_version,
     is_scipy_available,
     is_sentencepiece_available,
     is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
     is_unidecode_available,
     is_wandb_available,
     is_xformers_available,
+    is_xformers_version,
     requires_backends,
 )
 from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
 DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
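Since both constants are read from the environment at import time, backend selection has to happen before `diffusers` is imported. A minimal sketch (the set of accepted backend names beyond `"native"` depends on the dispatch table added on this branch):

```python
import os

os.environ["DIFFUSERS_ATTN_BACKEND"] = "native"  # default shown in the diff
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"        # opt into the extra input checks

import diffusers  # noqa: E402 -- must come after the env vars are set
```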
@@ -219,6 +219,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
 _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
 _nltk_available, _nltk_version = _is_package_available("nltk")
 _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
 
 
 def is_torch_available():
@@ -377,6 +380,18 @@ def is_hpu_available():
     return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
 
 
+def is_sageattention_available():
+    return _sageattention_available
+
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
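These helpers simply surface the `_is_package_available` probes above, so downstream code can branch on optional dependencies without try/except imports. An illustrative guard (the labels chosen here are arbitrary):

```python
from diffusers.utils import is_flash_attn_available, is_sageattention_available

if is_flash_attn_available():
    label = "flash-attn present"
elif is_sageattention_available():
    label = "sageattention present"
else:
    label = "native PyTorch attention only"
print(label)
```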
@@ -803,6 +818,51 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)
 
 
+def is_xformers_version(operation: str, version: str):
+    """
+    Compares the current xformers version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _xformers_available:
+        return False
+    return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+    """
+    Compares the current sageattention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _sageattention_available:
+        return False
+    return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+    """
+    Compares the current flash-attention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _flash_attn_available:
+        return False
+    return compare_versions(parse(_flash_attn_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
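All three helpers return `False` when the corresponding package is missing, so they are safe to call unconditionally. An illustrative version gate (the version strings are arbitrary examples):

```python
from diffusers.utils import is_flash_attn_version, is_xformers_version

recent_flash_attn = is_flash_attn_version(">=", "2.6.0")
legacy_xformers = is_xformers_version("<", "0.0.26")
print(recent_flash_attn, legacy_xformers)
```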