Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 23:44:30 +08:00)

Compare commits: 1 commit, memory-opt ... mem-eff-fu
| Author | SHA1 | Date |
|---|---|---|
|  | eacdf142b4 |  |
@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
 import math
+from contextlib import nullcontext
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -20,7 +21,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_accelerate_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
@@ -36,6 +37,10 @@ if is_xformers_available():
 else:
     xformers = None
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 
 @maybe_allow_in_graph
 class Attention(nn.Module):
@@ -663,9 +668,12 @@ class Attention(nn.Module):
         return encoder_hidden_states
 
     @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def fuse_projections(self, fuse=True, low_cpu_mem_usage=True):
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype
+        init_ctx = init_empty_weights if is_accelerate_available() and low_cpu_mem_usage else nullcontext
+        if is_accelerate_available():
+            accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
 
         if not self.is_cross_attention:
             # fetch weight matrices.
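The two added lines above set up the low-memory path: when `accelerate` is installed and `low_cpu_mem_usage` is enabled, the fused layer is created under `init_empty_weights` (on the meta device, so nothing is allocated for an initialization that would be overwritten anyway) and later materialized in place with `set_module_tensor_to_device`; `accepts_dtype` is probed with `inspect.signature` because older `accelerate` releases do not take a `dtype` argument. A minimal standalone sketch of that pattern, with made-up layer sizes, not the diffusers implementation itself:

```python
import inspect
from contextlib import nullcontext

import torch
from torch import nn

try:
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    accelerate_available = True
except ImportError:
    accelerate_available = False

low_cpu_mem_usage = True
init_ctx = init_empty_weights if accelerate_available and low_cpu_mem_usage else nullcontext

with init_ctx():
    # Under init_empty_weights the parameters live on the "meta" device:
    # no memory is spent on weights we are about to overwrite.
    fused = nn.Linear(320, 3 * 320, bias=False)

# Stand-in for the concatenated q/k/v weight matrices.
weights = torch.randn(3 * 320, 320)

if accelerate_available and low_cpu_mem_usage:
    # Older accelerate versions do not accept `dtype`, hence the signature probe.
    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
    if accepts_dtype:
        set_module_tensor_to_device(fused, "weight", "cpu", value=weights, dtype=torch.float32)
    else:
        set_module_tensor_to_device(fused, "weight", "cpu", value=weights)
else:
    with torch.no_grad():
        fused.weight.copy_(weights)
```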
@@ -674,22 +682,61 @@
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_qkv.weight.copy_(concatenated_weights)
+
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_qkv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.to_q
+                del self.to_k
+                del self.to_v
 
         else:
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_kv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights, dtype=dtype)
+                else:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_kv.weight.copy_(concatenated_weights)
+
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
-                self.to_kv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_kv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.to_k
+                del self.to_v
 
         # handle added projections for SD3 and others.
         if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
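The fusion itself is just a row-wise concatenation of the existing weight matrices into one larger `nn.Linear`. A quick, self-contained sanity check of that equivalence (toy sizes, unrelated to any real checkpoint); the split order matches the `torch.split` used by the fused processors further down:

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 64
to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)

# Stack [Wq; Wk; Wv] along the output dimension, exactly like the diff above.
concatenated = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(concatenated)

x = torch.randn(2, 10, dim)
q, k, v = torch.split(to_qkv(x), to_qkv.out_features // 3, dim=-1)

# The fused projection reproduces the three separate projections.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```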
@@ -699,15 +746,38 @@
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(
-                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
-            )
-            self.to_added_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_added_qkv = nn.Linear(
+                    in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+                )
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_added_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_added_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_added_qkv.weight.copy_(concatenated_weights)
+
             if self.added_proj_bias:
                 concatenated_bias = torch.cat(
                     [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                 )
-                self.to_added_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(
+                            self.to_added_qkv, "bias", device, value=concatenated_bias, dtype=dtype
+                        )
+                    else:
+                        set_module_tensor_to_device(self.to_added_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_added_qkv.bias.copy_(concatenated_bias)
+
+            if low_cpu_mem_usage:
+                del self.add_q_proj
+                del self.add_k_proj
+                del self.add_v_proj
 
         self.fused_projections = fuse
 
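On this branch the whole change is exposed through the new `low_cpu_mem_usage` argument of `Attention.fuse_projections`; note that, unlike before, the unfused `to_q`/`to_k`/`to_v` (and `add_*_proj`) modules are deleted afterwards, so a matching fused attention processor has to be installed before the module is called again. A hypothetical call on a standalone `Attention` module (constructor arguments are illustrative):

```python
from diffusers.models.attention_processor import Attention

attn = Attention(query_dim=320, heads=8, dim_head=40, bias=True)

# Fuse to_q/to_k/to_v into a single to_qkv. With low_cpu_mem_usage=True the fused
# layer is first allocated on the meta device (when accelerate is installed) and the
# original projection layers are deleted to reclaim memory.
attn.fuse_projections(fuse=True, low_cpu_mem_usage=True)

print(attn.to_qkv.weight.shape)  # torch.Size([960, 320])
print(hasattr(attn, "to_q"))     # False on this branch
```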
@@ -1770,6 +1840,71 @@ class FluxSingleAttnProcessor2_0:
         return hidden_states
 
 
+class FusedFluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, _, _ = hidden_states.shape
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states
+
+
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
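The new processor assumes the projections have already been fused, so a single `to_qkv` matmul replaces three: the result is split into thirds along the last dimension and reshaped into heads for `scaled_dot_product_attention`. The shape bookkeeping on toy tensors (sizes are made up; this is not Flux-specific):

```python
import torch
import torch.nn.functional as F

batch, seq_len, heads, head_dim = 2, 16, 4, 8
inner_dim = heads * head_dim                      # 32

# Output of a fused to_qkv projection: (batch, seq_len, 3 * inner_dim).
qkv = torch.randn(batch, seq_len, 3 * inner_dim)

split_size = qkv.shape[-1] // 3                   # inner_dim
query, key, value = torch.split(qkv, split_size, dim=-1)

# (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim), as SDPA expects.
query = query.view(batch, -1, heads, head_dim).transpose(1, 2)
key = key.view(batch, -1, heads, head_dim).transpose(1, 2)
value = value.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)
print(out.shape)  # torch.Size([2, 16, 32])
```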
@@ -1868,6 +2003,110 @@ class FluxAttnProcessor2_0:
         return hidden_states, encoder_hidden_states
 
 
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
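Putting the pieces together, the intended flow on this branch appears to be: fuse the projections of every Flux attention module, then install the matching fused processors (dual-stream blocks get `FusedFluxAttnProcessor2_0`, single-stream blocks `FusedFluxSingleAttnProcessor2_0`). A hedged sketch; the helper below and the way the transformer is traversed are assumptions for illustration, not part of the commit:

```python
from diffusers.models.attention_processor import (
    Attention,
    FusedFluxAttnProcessor2_0,
    FusedFluxSingleAttnProcessor2_0,
)


def fuse_flux_attention(transformer):
    """Fuse q/k/v projections and install the matching fused processors (illustrative helper)."""
    for module in transformer.modules():
        if not isinstance(module, Attention):
            continue
        module.fuse_projections(fuse=True, low_cpu_mem_usage=True)
        # Dual-stream blocks carry the added (context) projections and therefore
        # end up with a to_added_qkv after fusing; single-stream blocks do not.
        if hasattr(module, "to_added_qkv"):
            module.set_processor(FusedFluxAttnProcessor2_0())
        else:
            module.set_processor(FusedFluxSingleAttnProcessor2_0())


# transformer = FluxTransformer2DModel.from_pretrained(...)  # assumed to be loaded elsewhere
# fuse_flux_attention(transformer)
```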