Compare commits

...

20 Commits

Author      SHA1        Message                                           Date
Sayak Paul  d7b91c745c  Merge branch 'main' into fuse-projections-pixart  2023-12-27 20:22:51 +05:30
sayakpaul   e62b209d43  check again                                       2023-12-27 09:54:54 +05:30
sayakpaul   ca50e6558f  bring                                             2023-12-27 09:44:47 +05:30
sayakpaul   25f687cb69  checkin'                                          2023-12-27 09:39:18 +05:30
sayakpaul   a64512aedd  print                                             2023-12-27 09:33:48 +05:30
sayakpaul   ea517c4edd  remove comments.                                  2023-12-27 09:27:07 +05:30
sayakpaul   7db10a9279  remove _explicitly_mark_as_cross_attention        2023-12-27 08:44:04 +05:30
Sayak Paul  b50688b2f9  Apply suggestions from code review (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)  2023-12-27 08:36:12 +05:30
sayakpaul   31ad79d9e1  merge main and resolve conflicts                  2023-12-24 19:20:16 +05:30
sayakpaul   38ada6a526  remove prints.                                    2023-12-15 11:42:13 +05:30
sayakpaul   afd6df57ed  remove prints.                                    2023-12-15 11:38:43 +05:30
sayakpaul   f2577f5a98  fix                                               2023-12-15 11:37:42 +05:30
sayakpaul   85612d7b04  enable using same dim for cross-attention         2023-12-15 11:31:22 +05:30
sayakpaul   f8778c8f92  debug                                             2023-12-15 11:22:46 +05:30
sayakpaul   1f920408d8  debug                                             2023-12-15 10:29:15 +05:30
sayakpaul   13c1e5e050  debug                                             2023-12-15 10:25:17 +05:30
sayakpaul   4e429c4197  debug                                             2023-12-15 10:20:30 +05:30
sayakpaul   d33a03caa7  debug                                             2023-12-15 10:16:59 +05:30
sayakpaul   d5d77e2bac  debug                                             2023-12-15 10:12:43 +05:30
sayakpaul   a10bea8317  feat: projection fusion for pixart alpha          2023-12-15 10:03:38 +05:30
3 changed files with 204 additions and 5 deletions

File 1 of 3

@@ -113,6 +113,7 @@ class Attention(nn.Module):
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.is_cross_attention = cross_attention_dim is not None
         self.query_dim = query_dim
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
@@ -697,7 +698,7 @@ class Attention(nn.Module):
     @torch.no_grad()
     def fuse_projections(self, fuse=True):
-        is_cross_attention = self.cross_attention_dim != self.query_dim
+        is_cross_attention = self.is_cross_attention
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype
@@ -1331,10 +1332,15 @@ class FusedAttnProcessor2_0:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

         args = () if USE_PEFT_BACKEND else (scale,)
         if encoder_hidden_states is None:
-            qkv = attn.to_qkv(hidden_states, *args)
-            split_size = qkv.shape[-1] // 3
-            query, key, value = torch.split(qkv, split_size, dim=-1)
+            print(f"Cross attention: {encoder_hidden_states is not None}, under fusion.")
+            # qkv = attn.to_qkv(hidden_states, *args)
+            # split_size = qkv.shape[-1] // 3
+            # query, key, value = torch.split(qkv, split_size, dim=-1)
+            query = attn.to_q(hidden_states, *args)
+            key = attn.to_k(hidden_states, *args)
+            value = attn.to_v(hidden_states, *args)
         else:
             if attn.norm_cross:
                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
@@ -1344,6 +1350,13 @@ class FusedAttnProcessor2_0:
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
+            # else:
+            #     query = attn.to_q(hidden_states, *args)
+            #     if attn.norm_cross:
+            #         encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            #     key = attn.to_k(encoder_hidden_states, *args)
+            #     value = attn.to_v(encoder_hidden_states, *args)

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
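For orientation, here is a minimal, standalone sketch of the fusion idea the hunks above build on: the separate query/key/value linear layers are concatenated into one projection whose output is split back into three chunks. This is plain PyTorch written for this note; names such as to_qkv mirror the diff, but nothing below is taken from the diffusers implementation verbatim.

import torch
import torch.nn as nn

# Illustrative dimensions only; real attention blocks use their own sizes.
dim = 64
to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)

# Fuse: stack the three weight matrices along the output dimension so that a
# single matmul produces q, k and v at once.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

x = torch.randn(2, 16, dim)
qkv = to_qkv(x)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

# The fused path reproduces the separate projections.
assert torch.allclose(query, to_q(x), atol=1e-6)
assert torch.allclose(key, to_k(x), atol=1e-6)
assert torch.allclose(value, to_v(x), atol=1e-6)

In the cross-attention branch of the diff only key and value are fused (to_kv, split in two), since query is computed from hidden_states while key/value come from encoder_hidden_states. The debug print and the commented-out fused path reflect the work-in-progress state of these commits.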

File 2 of 3

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.nn.functional as F
@@ -22,6 +22,14 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
 from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
 from .attention import BasicTransformerBlock
+from .attention_processor import (
+    ADDED_KV_ATTENTION_PROCESSORS,
+    CROSS_ATTENTION_PROCESSORS,
+    Attention,
+    AttentionProcessor,
+    AttnAddedKVProcessor,
+    AttnProcessor,
+)
 from .embeddings import PatchEmbed, PixArtAlphaTextProjection
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .modeling_utils import ModelMixin
@@ -243,6 +251,44 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -457,3 +503,81 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             return (output,)

         return Transformer2DModelOutput(sample=output)
+
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor, _remove_lora=_remove_lora)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnAddedKVProcessor()
+        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+            processor = AttnProcessor()
+        else:
+            raise ValueError(
+                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+            )
+
+        self.set_attn_processor(processor, _remove_lora=True)
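The attn_processors property and set_attn_processor added above both rely on the same recursive walk over named_children(). Below is a small self-contained sketch of that pattern; the Toy* classes and set_attn_processors name are invented for illustration and are not part of diffusers.

import torch.nn as nn

class ToyProcessor:
    pass

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.processor = ToyProcessor()

    def get_processor(self):
        return self.processor

    def set_processor(self, processor):
        self.processor = processor

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.ModuleDict({"attn1": ToyAttention(), "attn2": ToyAttention()})

    @property
    def attn_processors(self):
        # Walk the module tree and key each processor by its weight-style path,
        # mirroring fn_recursive_add_processors in the diff above.
        processors = {}

        def recurse(name, module):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                recurse(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            recurse(name, module)
        return processors

    def set_attn_processors(self, processors):
        # Mirror of the setter: walk the same paths and pop entries by key.
        def recurse(name, module):
            if hasattr(module, "set_processor"):
                module.set_processor(processors.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                recurse(f"{name}.{sub_name}", child)

        for name, module in self.named_children():
            recurse(name, module)

model = ToyModel()
print(list(model.attn_processors))  # ['block.attn1.processor', 'block.attn2.processor']
model.set_attn_processors({k: ToyProcessor() for k in model.attn_processors})

Because the setter pops entries by the same "{name}.processor" keys, a dict argument must contain exactly one entry per attention layer, which is what the length check in set_attn_processor enforces.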

File 3 of 3

@@ -24,6 +24,7 @@ from transformers import T5EncoderModel, T5Tokenizer

 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, Transformer2DModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...schedulers import DPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
@@ -663,6 +664,67 @@ class PixArtAlphaPipeline(DiffusionPipeline):
         return samples

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections with unet->transformer, UNet->Transformer
+    def fuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_transformer = False
+        self.fusing_vae = False
+
+        if transformer:
+            self.fusing_transformer = True
+            self.transformer.fuse_qkv_projections()
+            self.transformer.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections with unet->transformer, UNet->Transformer
+    def unfuse_qkv_projections(self, transformer: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            transformer (`bool`, defaults to `True`): To apply fusion on the Transformer.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        if transformer:
+            if not self.fusing_transformer:
+                logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.transformer.unfuse_qkv_projections()
+                self.fusing_transformer = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
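Taken together, the pipeline-level API added in this file would be used roughly as follows. This is a hedged usage sketch: the checkpoint id and the prompt are assumptions for illustration, not taken from this diff.

import torch
from diffusers import PixArtAlphaPipeline

# Checkpoint id assumed for illustration; any PixArt-Alpha checkpoint should work.
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")

# Fuse QKV projections in the transformer and the VAE (both default to True above).
pipe.fuse_qkv_projections()
image = pipe("a small cactus with a happy face in the Sahara desert").images[0]

# Restore the original, unfused attention processors.
pipe.unfuse_qkv_projections()

Note that fuse_qkv_projections() swaps in FusedAttnProcessor2_0, so it requires PyTorch 2.0's scaled dot-product attention to be available.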