Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 02:44:53 +08:00)

Compare commits: 70 commits, `dynamic-up`...`sdxl/feat`
Commits in this range (by SHA1):

`ad9c0c2989`, `8947e55efe`, `f94a4ef428`, `d48bd10eae`, `c9f7ed4350`, `ce9fa3c732`, `7618b3575a`, `84117cae95`, `95167589e7`, `e65ddcd08c`, `d485abdd27`, `abf9ebc766`, `c6d5e86a00`, `8fadb14a96`, `a5fb4d761d`, `d17bbbd901`, `8d17831bf8`, `a7a952dc11`, `93b5f92a60`, `ff28fdd884`, `7d8b91300c`, `0432297da6`, `d944d8b108`, `2632a8b272`, `4f882ab0df`, `1688fee353`, `8c0c3e2f30`, `418d33c7f7`, `44d4263ac4`, `253aaf0d5d`, `4e120d86ca`, `8da35af8d0`, `6c5712cdbe`, `4b66d10240`, `e0848ebbde`, `2c02f0730d`, `0afc2b455f`, `be647c3419`, `981dc3abfa`, `c7f78bf54c`, `b64e533607`, `e51bc7e744`, `23f8404bb2`, `ba14a08235`, `7b1688873f`, `5175b91ccb`, `32012cea11`, `a0b9066244`, `06bb65b107`, `580a1c2dc2`, `678577b920`, `94fb74a7f8`, `a7da467125`, `c4eaec3ae6`, `01c6038e9d`, `a030797ff1`, `86027e52fd`, `4e556a92cf`, `c5a5f85293`, `a4da76b67c`, `f5b091d1be`, `88c7e16693`, `bd855d78b3`, `096fffbb65`, `ff04934d41`, `75ae3df500`, `215bf3b667`, `55f1842ad3`, `afb517ae61`, `bf4e645a77`

The diff between the two branches follows, hunk by hunk.
```diff
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mechanisms.
 ## AttnProcessor2_0

 [[autodoc]] models.attention_processor.AttnProcessor2_0

+## FusedAttnProcessor2_0
+
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
 ## LoRAAttnProcessor

 [[autodoc]] models.attention_processor.LoRAAttnProcessor
```
```diff
@@ -33,8 +33,8 @@ if is_torch_available():
     _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
-    _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["embeddings"] = ["ImageProjection"]
+    _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["prior_transformer"] = ["PriorTransformer"]
     _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformer_2d"] = ["Transformer2DModel"]
```
```diff
@@ -113,12 +113,14 @@ class Attention(nn.Module):
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim

         # we make use of this private variable to know whether this class is loaded
```
```diff
@@ -180,6 +182,7 @@ class Attention(nn.Module):
         else:
             linear_cls = LoRACompatibleLinear

+        self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
```
```diff
@@ -692,6 +695,32 @@ class Attention(nn.Module):

         return encoder_hidden_states

+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
+
 class AttnProcessor:
     r"""
```
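The hunk above is the core of the change: `Attention.fuse_projections` concatenates the separate Q/K/V (or K/V) weight matrices into a single linear layer. A minimal, self-contained sketch of the same idea with plain `nn.Linear` layers (not the `Attention` module itself) shows that the fused projection reproduces the separate ones:

```python
# Minimal sketch: fusing three bias-free projections into one matmul. This
# mirrors the idea behind `Attention.fuse_projections`, using plain nn.Linear.
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Concatenate weights along the output dimension: shape (3 * dim, dim).
fused_weight = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv = nn.Linear(fused_weight.shape[1], fused_weight.shape[0], bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(fused_weight)

x = torch.randn(2, 16, dim)  # (batch, sequence, dim)
q, k, v = torch.split(to_qkv(x), dim, dim=-1)  # one matmul instead of three

assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```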
```diff
@@ -1184,9 +1213,6 @@ class AttnProcessor2_0:
         scale: float = 1.0,
     ) -> torch.FloatTensor:
         residual = hidden_states
-
-        args = () if USE_PEFT_BACKEND else (scale,)
-
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)

```
```diff
@@ -1253,6 +1279,103 @@ class AttnProcessor2_0:
         return hidden_states


+class FusedAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+    <Tip warning={true}>
+
+    This API is currently 🧪 experimental and may change in the future.
+
+    </Tip>
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires PyTorch 2.0 or newer. Please upgrade PyTorch to use it."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states, *args)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states, *args)
+
+            kv = attn.to_kv(encoder_hidden_states, *args)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
```
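`FusedAttnProcessor2_0` reads from the fused `to_qkv`/`to_kv` layers created by `fuse_projections`. A hedged sketch of wiring the two together on a bare `Attention` block and checking the output is unchanged (assumes a diffusers build that contains this branch, with PyTorch 2.x):

```python
# Sketch: compare the standard and fused processors on a small self-attention
# block. Shapes are arbitrary; chosen only to keep the example light.
import torch
from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0

attn = Attention(query_dim=64, heads=4, dim_head=16).eval()
x = torch.randn(1, 32, 64)  # (batch, tokens, channels)

with torch.no_grad():
    reference = attn(x)                      # default processor (AttnProcessor2_0 on PyTorch 2.x)
    attn.fuse_projections()                  # builds attn.to_qkv from to_q/to_k/to_v
    attn.set_processor(FusedAttnProcessor2_0())
    fused = attn(x)

assert torch.allclose(reference, fused, atol=1e-5)
```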
```diff
@@ -2251,6 +2374,7 @@ CROSS_ATTENTION_PROCESSORS = (
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
```
```diff
@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
```
```diff
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
             return (dec,)

         return DecoderOutput(sample=dec)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
```
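A hedged usage sketch of the new `AutoencoderKL` methods, mirroring what the SDXL pipeline's `fuse_qkv_projections(vae=True)` does further below. The checkpoint id and the attribute path to the mid-block attention are assumptions chosen for illustration:

```python
# Sketch (assumes a diffusers build with this change): fuse the VAE's attention
# projections, then switch it to the fused processor so to_qkv is actually used.
from diffusers import AutoencoderKL
from diffusers.models.attention_processor import FusedAttnProcessor2_0

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae")

vae.fuse_qkv_projections()
vae.set_attn_processor(FusedAttnProcessor2_0())

# The mid-block self-attention now exposes a single fused projection layer.
attn = vae.encoder.mid_block.attentions[0]  # attribute path assumed for illustration
print(hasattr(attn, "to_qkv"))  # True after fusion
```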
```diff
@@ -25,6 +25,7 @@ from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
```
```diff
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
```
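A hedged sketch of the fuse/unfuse round trip on `UNet2DConditionModel`: `fuse_qkv_projections` stashes the current processors in `original_attn_processors`, and `unfuse_qkv_projections` re-installs them. The tiny constructor config is hypothetical (in the style of the fast-test configs) so the example stays light:

```python
# Sketch (assumes this branch of diffusers and PyTorch 2.x, so the default
# processor is AttnProcessor2_0). Config values are illustrative only.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0

unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)

unet.fuse_qkv_projections()                        # saves original processors, fuses weights
unet.set_attn_processor(FusedAttnProcessor2_0())   # use the fused layers
assert all(isinstance(p, FusedAttnProcessor2_0) for p in unet.attn_processors.values())

unet.unfuse_qkv_projections()                      # restores the saved processors
assert all(isinstance(p, AttnProcessor2_0) for p in unet.attn_processors.values())
```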
```diff
@@ -34,6 +34,7 @@ from ...loaders import (
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
```
```diff
@@ -681,7 +682,6 @@ class StableDiffusionXLPipeline(
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
```
```diff
@@ -692,6 +692,7 @@ class StableDiffusionXLPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
```
```diff
@@ -729,6 +730,65 @@ class StableDiffusionXLPipeline(
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): Whether to apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): Whether to apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): Whether to undo fusion on the UNet.
+            vae (`bool`, defaults to `True`): Whether to undo fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
```
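At the pipeline level, the two new methods wrap the UNet and VAE calls shown above. A hedged end-to-end usage sketch (checkpoint id, dtype, device, and prompt are illustrative only):

```python
# Sketch: enable and later disable fused QKV projections on an SDXL pipeline.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.fuse_qkv_projections()    # fuses both UNet and VAE projections by default
image = pipe("a photo of an astronaut riding a horse").images[0]
pipe.unfuse_qkv_projections()  # restore the original, unfused projections
```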
```diff
@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
```
```diff
@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
```
```diff
@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.activations import get_activation
-from ...models.attention import Attention
 from ...models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
```
```diff
@@ -1000,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                 setattr(upsample_block, k, None)

+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
```
```diff
@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
+        max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
         if self.config.timestep_spacing in ["linspace", "trailing"]:
-            return self.sigmas.max()
+            return max_sigma

-        return (self.sigmas.max() ** 2 + 1) ** 0.5
+        return (max_sigma**2 + 1) ** 0.5

     @property
     def step_index(self):
```
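A small standalone sketch of why the property now routes through `max_sigma`: once `self.sigmas` can be a plain Python list (see the next hunk), `.max()` is no longer available, while the built-in `max()` works for both representations:

```python
# Sketch: the same expression must handle sigmas stored either as a torch
# tensor or as a plain Python list of floats (standalone helper, not the class).
import torch

def init_noise_sigma(sigmas, timestep_spacing="linspace"):
    max_sigma = max(sigmas) if isinstance(sigmas, list) else sigmas.max()
    if timestep_spacing in ["linspace", "trailing"]:
        return max_sigma
    return (max_sigma**2 + 1) ** 0.5

as_tensor = torch.tensor([14.6, 8.0, 1.0, 0.0])
as_list = as_tensor.tolist()

print(init_noise_sigma(as_tensor))  # tensor(14.6000)
print(init_noise_sigma(as_list))    # 14.6
```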
```diff
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

         self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if sigmas.device.type == "cuda":
+            self.sigmas = self.sigmas.tolist()
         self._step_index = None

     def _sigma_to_t(self, sigma, log_sigmas):
```
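The list conversion on CUDA presumably avoids per-step device-to-host synchronizations when individual sigmas are later consumed as Python floats. A hedged sketch of the difference (needs a CUDA device to run):

```python
# Sketch: indexing a CUDA tensor yields a 0-d device tensor, so turning it into
# a Python float forces a device-to-host sync; a plain list of floats does not.
import torch

sigmas_gpu = torch.linspace(14.6, 0.0, 1000, device="cuda")
sigmas_list = sigmas_gpu.tolist()

sigma_t = sigmas_gpu[3]   # still on the GPU; float(sigma_t) would synchronize
sigma_l = sigmas_list[3]  # already a host-side Python float

print(type(sigma_t), type(sigma_l))  # <class 'torch.Tensor'> <class 'float'>
```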
```diff
@@ -938,6 +938,37 @@ class StableDiffusionXLPipelineFastTests(

         assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3

+    def test_stable_diffusion_xl_with_fused_qkv_projections(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionXLPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        original_image_slice = image[0, -3:, -3:, -1]
+
+        sd_pipe.fuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice_fused = image[0, -3:, -3:, -1]
+
+        sd_pipe.unfuse_qkv_projections()
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice_disabled = image[0, -3:, -3:, -1]
+
+        assert np.allclose(
+            original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
+        ), "Fusion of QKV projections shouldn't affect the outputs."
+        assert np.allclose(
+            image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Outputs should not change when fused QKV projections are disabled."
+        assert np.allclose(
+            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+        ), "Original outputs should match when fused QKV projections are disabled."
+

 @slow
 class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
```
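To exercise just the new test locally, something like the following should work; the test path is an assumption based on the usual diffusers test layout:

```python
# Sketch: run only the new fused-QKV test via pytest's Python API.
import pytest

pytest.main(["-q", "-k", "fused_qkv_projections", "tests/pipelines/stable_diffusion_xl/"])
```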