Compare commits

1 commit

Author     SHA1        Message                            Date
sayakpaul  eacdf142b4  add memory efficient fusion (wip)  2024-08-16 09:36:08 +05:30

@@ -13,6 +13,7 @@
 # limitations under the License.
 import inspect
 import math
+from contextlib import nullcontext
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -20,7 +21,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, is_accelerate_available, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
@@ -36,6 +37,10 @@ if is_xformers_available():
 else:
     xformers = None
 
+if is_accelerate_available():
+    from accelerate import init_empty_weights
+    from accelerate.utils import set_module_tensor_to_device
+
 
 @maybe_allow_in_graph
 class Attention(nn.Module):
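
Note (not from the diff): `init_empty_weights` from accelerate creates module parameters on the meta device, so no real storage is allocated until a tensor is assigned, typically via `set_module_tensor_to_device`. A minimal sketch of that pattern with toy shapes, assuming accelerate is installed:

    import torch
    from torch import nn
    from accelerate import init_empty_weights
    from accelerate.utils import set_module_tensor_to_device

    # The layer is created on the meta device: shapes exist, storage does not.
    with init_empty_weights():
        layer = nn.Linear(8, 24, bias=False)

    # Materialize the weight in place from an already-computed tensor,
    # skipping the throwaway random initialization of the same size.
    fused_weight = torch.randn(24, 8)
    set_module_tensor_to_device(layer, "weight", "cpu", value=fused_weight)

This is the pattern the `fuse_projections` changes below rely on when `low_cpu_mem_usage=True`.
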
@@ -663,9 +668,12 @@ class Attention(nn.Module):
         return encoder_hidden_states
 
     @torch.no_grad()
-    def fuse_projections(self, fuse=True):
+    def fuse_projections(self, fuse=True, low_cpu_mem_usage=True):
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype
+        init_ctx = init_empty_weights if is_accelerate_available() and low_cpu_mem_usage else nullcontext
+        if is_accelerate_available():
+            accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
 
         if not self.is_cross_attention:
             # fetch weight matrices.
@@ -674,22 +682,61 @@ class Attention(nn.Module):
             out_features = concatenated_weights.shape[0]
 
             # create a new single projection layer and copy over the weights.
-            self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_qkv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
-                self.to_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_qkv.bias.copy_(concatenated_bias)
+            if low_cpu_mem_usage:
+                del self.to_q
+                del self.to_k
+                del self.to_v
 
         else:
             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
-            self.to_kv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights, dtype=dtype)
+                else:
+                    set_module_tensor_to_device(self.to_kv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_kv.weight.copy_(concatenated_weights)
             if self.use_bias:
                 concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
-                self.to_kv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias, dtype=dtype)
+                    else:
+                        set_module_tensor_to_device(self.to_kv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_kv.bias.copy_(concatenated_bias)
+            if low_cpu_mem_usage:
+                del self.to_k
+                del self.to_v
 
         # handle added projections for SD3 and others.
         if hasattr(self, "add_q_proj") and hasattr(self, "add_k_proj") and hasattr(self, "add_v_proj"):
@@ -699,15 +746,38 @@ class Attention(nn.Module):
             in_features = concatenated_weights.shape[1]
             out_features = concatenated_weights.shape[0]
 
-            self.to_added_qkv = nn.Linear(
-                in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
-            )
-            self.to_added_qkv.weight.copy_(concatenated_weights)
+            with init_ctx():
+                self.to_added_qkv = nn.Linear(
+                    in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+                )
+            if is_accelerate_available() and low_cpu_mem_usage:
+                if accepts_dtype:
+                    set_module_tensor_to_device(
+                        self.to_added_qkv, "weight", device, value=concatenated_weights, dtype=dtype
+                    )
+                else:
+                    set_module_tensor_to_device(self.to_added_qkv, "weight", device, value=concatenated_weights)
+            else:
+                self.to_added_qkv.weight.copy_(concatenated_weights)
             if self.added_proj_bias:
                 concatenated_bias = torch.cat(
                     [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
                 )
-                self.to_added_qkv.bias.copy_(concatenated_bias)
+                if is_accelerate_available() and low_cpu_mem_usage:
+                    if accepts_dtype:
+                        set_module_tensor_to_device(
+                            self.to_added_qkv, "bias", device, value=concatenated_bias, dtype=dtype
+                        )
+                    else:
+                        set_module_tensor_to_device(self.to_added_qkv, "bias", device, value=concatenated_bias)
+                else:
+                    self.to_added_qkv.bias.copy_(concatenated_bias)
+            if low_cpu_mem_usage:
+                del self.add_q_proj
+                del self.add_k_proj
+                del self.add_v_proj
 
         self.fused_projections = fuse
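
For reference (not part of the diff), a minimal sketch of exercising the new flag on a standalone Attention module; the constructor arguments here are illustrative:

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)

    # Build the fused to_qkv projection. With low_cpu_mem_usage=True the fused
    # layer is allocated empty (meta device) and filled in place, and the
    # original to_q/to_k/to_v modules are deleted afterwards.
    attn.fuse_projections(fuse=True, low_cpu_mem_usage=True)
    assert hasattr(attn, "to_qkv")
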
@@ -1770,6 +1840,71 @@ class FluxSingleAttnProcessor2_0:
         return hidden_states
 
 
+class FusedFluxSingleAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxSingleAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, _, _ = hidden_states.shape
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states
+
+
 class FluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
@@ -1868,6 +2003,110 @@ class FluxAttnProcessor2_0:
         return hidden_states, encoder_hidden_states
 
 
+class FusedFluxAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+            batch_size, -1, attn.heads, head_dim
+        ).transpose(1, 2)
+
+        if attn.norm_added_q is not None:
+            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+        if attn.norm_added_k is not None:
+            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+        # attention
+        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            # YiYi to-do: update uising apply_rotary_emb
+            # from ..embeddings import apply_rotary_emb
+            # query = apply_rotary_emb(query, image_rotary_emb)
+            # key = apply_rotary_emb(key, image_rotary_emb)
+            query, key = apply_rope(query, key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        encoder_hidden_states, hidden_states = (
+            hidden_states[:, : encoder_hidden_states.shape[1]],
+            hidden_states[:, encoder_hidden_states.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
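
Putting the pieces together (a sketch, not part of this commit; the module traversal and processor wiring below are assumed rather than taken from the Flux model code): fuse every Attention module in a Flux transformer and install the matching fused processor.

    import torch
    from diffusers.models.attention_processor import (
        Attention,
        FusedFluxAttnProcessor2_0,
        FusedFluxSingleAttnProcessor2_0,
    )

    def fuse_flux_attention(model: torch.nn.Module) -> None:
        for module in model.modules():
            if not isinstance(module, Attention):
                continue
            module.fuse_projections(fuse=True, low_cpu_mem_usage=True)
            # Double-stream blocks also get a fused context projection;
            # single-stream blocks only carry the sample-side to_qkv.
            if hasattr(module, "to_added_qkv"):
                module.set_processor(FusedFluxAttnProcessor2_0())
            else:
                module.set_processor(FusedFluxSingleAttnProcessor2_0())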