Mirror of https://github.com/huggingface/diffusers.git

Compare commits: custom-cod ... mem-eff-fu (1 commit, eacdf142b4)

The commit adds a `low_cpu_mem_usage` option to `Attention.fuse_projections` and introduces two fused Flux attention processors, `FusedFluxSingleAttnProcessor2_0` and `FusedFluxAttnProcessor2_0`.
```diff
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
 import math
+from contextlib import nullcontext
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -20,7 +21,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_accelerate_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
@@ -36,6 +37,10 @@ if is_xformers_available():
 else:
     xformers = None
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 
 @maybe_allow_in_graph
 class Attention(nn.Module):
```
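
The guarded import above brings in the two accelerate utilities the rest of the commit relies on. As orientation (not part of the diff), this is the pattern they enable: build a module on the meta device so its parameters take no real memory, then materialize a tensor in one step from a precomputed value. A minimal sketch, assuming accelerate is installed:

```python
import torch
from torch import nn

from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

# Construct the layer under init_empty_weights: parameters live on the "meta"
# device, so no memory is allocated at construction time.
with init_empty_weights():
    fused = nn.Linear(8, 24, bias=False)
assert fused.weight.is_meta

# Materialize the weight directly on the target device and dtype from an existing
# tensor, instead of allocating once at construction and again for the copy.
value = torch.randn(24, 8)
set_module_tensor_to_device(fused, "weight", "cpu", value=value, dtype=torch.float32)
assert torch.equal(fused.weight.data, value)
```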
```diff
@@ -663,9 +668,12 @@
         return encoder_hidden_states
 
     @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def fuse_projections(self, fuse=True, low_cpu_mem_usage=True):
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype
+        init_ctx = init_empty_weights if is_accelerate_available() and low_cpu_mem_usage else nullcontext
+        if is_accelerate_available():
+            accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
 
         if not self.is_cross_attention:
             # fetch weight matrices.
@@ -674,22 +682,61 @@
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_qkv.weight.copy_(concatenated_weights)
+
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_qkv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.to_q
+                del self.to_k
+                del self.to_v
+
         else:
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_kv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights, dtype=dtype)
+                else:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_kv.weight.copy_(concatenated_weights)
+
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
-                self.to_kv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_kv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.to_k
+                del self.to_v
 
         # handle added projections for SD3 and others.
         if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
@@ -699,15 +746,38 @@
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(
-                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
-            )
-            self.to_added_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_added_qkv = nn.Linear(
+                    in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+                )
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_added_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_added_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_added_qkv.weight.copy_(concatenated_weights)
+
             if self.added_proj_bias:
                 concatenated_bias = torch.cat(
                     [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                 )
-                self.to_added_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(
+                            self.to_added_qkv, "bias", device, value=concatenated_bias, dtype=dtype
+                        )
+                    else:
+                        set_module_tensor_to_device(self.to_added_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_added_qkv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.add_q_proj
+                del self.add_k_proj
+                del self.add_v_proj
 
         self.fused_projections = fuse
```
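
Taken together, the hunks above let `fuse_projections` build the fused `to_qkv` / `to_kv` / `to_added_qkv` layers without a second full-size allocation, and optionally drop the original per-projection modules. A hedged usage sketch against a standalone `Attention` block (the `low_cpu_mem_usage` argument exists only with this branch; released versions may differ):

```python
from diffusers.models.attention_processor import Attention

# A self-attention block with (here randomly initialized) q/k/v projections.
attn = Attention(query_dim=64, heads=4, dim_head=16, bias=True)

# Fuse q/k/v into a single to_qkv projection. With low_cpu_mem_usage=True the fused
# layer is created on the meta device and materialized in place via accelerate, and
# the now-redundant to_q/to_k/to_v modules are deleted to free their memory.
attn.fuse_projections(fuse=True, low_cpu_mem_usage=True)

assert hasattr(attn, "to_qkv")
assert not hasattr(attn, "to_q")
```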
```diff
@@ -1770,6 +1840,71 @@ class FluxSingleAttnProcessor2_0:
         return hidden_states
 
 
+class FusedFluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, _, _ = hidden_states.shape
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states
+
+
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
```
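
The fused processor assumes `fuse_projections` has already replaced `to_q`/`to_k`/`to_v` with a single `to_qkv`; the `torch.split` above then recovers query, key, and value from one matmul. A small self-contained sketch of that equivalence (illustrative layer names, not diffusers code):

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Fused projection: concatenate the three weight matrices along the output dim.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, 4, dim)
qkv = to_qkv(x)
q, k, v = torch.split(qkv, qkv.shape[-1] // 3, dim=-1)

# One matmul instead of three, same result as the separate projections.
assert torch.allclose(q, to_q(x), atol=1e-5)
assert torch.allclose(k, to_k(x), atol=1e-5)
assert torch.allclose(v, to_v(x), atol=1e-5)
```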
```diff
@@ -1868,6 +2003,110 @@ class FluxAttnProcessor2_0:
         return hidden_states, encoder_hidden_states
 
 
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
```
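
The two new processors are meant to be installed on a Flux transformer after its projections have been fused; dual-stream blocks also read the fused added (context) projections, while single-stream blocks only use `to_qkv`. A hedged wiring sketch (the `transformer` object, its `fuse_qkv_projections()` helper, and the `single_transformer_blocks` naming are assumptions about the surrounding Flux model code, not shown in this diff):

```python
from diffusers.models.attention_processor import (
    FusedFluxAttnProcessor2_0,
    FusedFluxSingleAttnProcessor2_0,
)

# `transformer` is assumed to be an already-loaded Flux transformer model.
# Fuse q/k/v (and added q/k/v) projections so to_qkv / to_added_qkv exist.
transformer.fuse_qkv_projections()

# Swap in the fused processors: single-stream blocks have no context
# projections, so they take the "Single" variant.
fused_procs = {}
for name in transformer.attn_processors:
    if "single_transformer_blocks" in name:
        fused_procs[name] = FusedFluxSingleAttnProcessor2_0()
    else:
        fused_procs[name] = FusedFluxAttnProcessor2_0()
transformer.set_attn_processor(fused_procs)
```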