Compare commits

...

70 Commits

Author SHA1 Message Date
sayakpaul
ad9c0c2989 remove printing 2023-12-09 11:01:15 +05:30
sayakpaul
8947e55efe print some sizes and shapes. 2023-12-09 10:46:01 +05:30
sayakpaul
f94a4ef428 remove print 2023-12-08 00:29:29 +05:30
sayakpaul
d48bd10eae print some info 2023-12-08 00:18:14 +05:30
sayakpaul
c9f7ed4350 remove funcs 2023-12-07 13:55:36 +05:30
sayakpaul
ce9fa3c732 fix 2023-12-07 12:57:40 +05:30
sayakpaul
7618b3575a fix 2023-12-07 12:56:20 +05:30
sayakpaul
84117cae95 _pointwise_conv_module_names 2023-12-07 12:52:44 +05:30
sayakpaul
95167589e7 feat: run pointwise convs with linear. 2023-12-07 12:28:55 +05:30
Sayak Paul
e65ddcd08c Merge branch 'main' into sdxl/feat 2023-12-05 15:19:17 +05:30
Sayak Paul
d485abdd27 Merge branch 'main' into sdxl/feat 2023-12-04 19:35:55 +05:30
Sayak Paul
abf9ebc766 Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-04 18:25:29 +05:30
sayakpaul
c6d5e86a00 fix more 2023-12-04 16:55:54 +05:30
sayakpaul
8fadb14a96 fix conditions. 2023-12-04 16:53:28 +05:30
sayakpaul
a5fb4d761d fix device and dtype 2023-12-04 16:49:13 +05:30
sayakpaul
d17bbbd901 fix more 2023-12-04 16:44:42 +05:30
sayakpaul
8d17831bf8 fix: disable call. 2023-12-04 16:41:11 +05:30
sayakpaul
a7a952dc11 Merge branch 'main' into sdxl/feat 2023-12-04 16:38:37 +05:30
sayakpaul
93b5f92a60 fix imports 2023-12-04 16:29:05 +05:30
sayakpaul
ff28fdd884 reflect patrick's suggestions. 2023-12-04 16:25:08 +05:30
Sayak Paul
7d8b91300c Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-04 16:10:02 +05:30
sayakpaul
0432297da6 comment out stuff that didn't work 2023-12-04 12:50:14 +05:30
sayakpaul
d944d8b108 add _change_to_group_norm_32 2023-12-04 12:39:41 +05:30
sayakpaul
2632a8b272 remove print. 2023-12-04 12:09:47 +05:30
sayakpaul
4f882ab0df fix: vae -> self.vae 2023-12-04 11:45:11 +05:30
sayakpaul
1688fee353 cast latent to bfloat16 2023-12-04 11:43:42 +05:30
sayakpaul
8c0c3e2f30 remove copy temporarily 2023-12-04 11:39:35 +05:30
sayakpaul
418d33c7f7 style 2023-12-04 11:38:03 +05:30
sayakpaul
44d4263ac4 check latent dtype 2023-12-04 11:35:50 +05:30
sayakpaul
253aaf0d5d bfloat16 computation. 2023-12-04 11:23:58 +05:30
sayakpaul
4e120d86ca check processor 2023-12-04 11:20:41 +05:30
sayakpaul
8da35af8d0 check 2023-12-04 11:19:36 +05:30
sayakpaul
6c5712cdbe Merge branch 'main' into sdxl/feat 2023-12-04 11:13:55 +05:30
sayakpaul
4b66d10240 better conditioning on disable_fused_qkv_projections 2023-12-04 08:19:30 +05:30
sayakpaul
e0848ebbde Empty-Commit 2023-12-03 20:14:49 +05:30
sayakpaul
2c02f0730d correct error message. 2023-12-03 20:04:37 +05:30
sayakpaul
0afc2b455f fix-copies 2023-12-03 10:11:22 +05:30
sayakpaul
be647c3419 Merge branch 'main' into sdxl/feat 2023-12-03 10:09:05 +05:30
sayakpaul
981dc3abfa fix: docs 2023-12-03 10:08:44 +05:30
sayakpaul
c7f78bf54c relax further 2023-12-03 09:56:05 +05:30
sayakpaul
b64e533607 relax assertions. 2023-12-03 09:53:10 +05:30
sayakpaul
e51bc7e744 add: test for qkv projection fusion. 2023-12-03 09:27:10 +05:30
sayakpaul
23f8404bb2 add: documentation and cleanups. 2023-12-03 09:19:36 +05:30
sayakpaul
ba14a08235 merge main and resolve conflicts 2023-12-02 21:03:48 +05:30
sayakpaul
7b1688873f add todos. 2023-12-02 21:02:53 +05:30
sayakpaul
5175b91ccb add fused projection to vae 2023-12-01 17:10:21 +05:30
sayakpaul
32012cea11 remove print 2023-12-01 14:47:43 +05:30
sayakpaul
a0b9066244 _enable_fused_qkv_projections 2023-12-01 14:46:49 +05:30
sayakpaul
06bb65b107 attn processors 2023-12-01 14:45:59 +05:30
sayakpaul
580a1c2dc2 apply attention processor within the method 2023-12-01 14:38:34 +05:30
sayakpaul
678577b920 enable disable 2023-12-01 14:34:55 +05:30
sayakpaul
94fb74a7f8 fix: qkv >-> k 2023-12-01 14:29:49 +05:30
sayakpaul
a7da467125 fix: unbind -> split 2023-12-01 14:29:01 +05:30
sayakpaul
c4eaec3ae6 more print 2023-12-01 13:53:03 +05:30
sayakpaul
01c6038e9d print 2023-12-01 13:47:37 +05:30
sayakpaul
a030797ff1 okay 2023-12-01 13:37:12 +05:30
sayakpaul
86027e52fd dtype 2023-12-01 13:32:12 +05:30
sayakpaul
4e556a92cf device 2023-12-01 13:30:17 +05:30
sayakpaul
c5a5f85293 device. 2023-12-01 13:28:13 +05:30
sayakpaul
a4da76b67c no grad 2023-12-01 13:25:22 +05:30
sayakpaul
f5b091d1be change to a better name 2023-12-01 13:23:50 +05:30
sayakpaul
88c7e16693 feat: introduce fused projections 2023-12-01 13:16:04 +05:30
sayakpaul
bd855d78b3 remove prints 2023-12-01 09:58:36 +05:30
sayakpaul
096fffbb65 comment 2023-12-01 09:54:17 +05:30
sayakpaul
ff04934d41 init_noise_sigma 2023-12-01 09:49:40 +05:30
sayakpaul
75ae3df500 make str 2023-12-01 09:46:20 +05:30
sayakpaul
215bf3b667 turn sigma a list 2023-12-01 09:44:22 +05:30
sayakpaul
55f1842ad3 print 2023-12-01 09:33:56 +05:30
sayakpaul
afb517ae61 from step 2023-12-01 09:24:13 +05:30
sayakpaul
bf4e645a77 debug 2023-12-01 09:14:35 +05:30
10 changed files with 343 additions and 8 deletions

View File

@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

View File

@@ -33,8 +33,8 @@ if is_torch_available():
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]

View File

@@ -113,12 +113,14 @@ class Attention(nn.Module):
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ class Attention(nn.Module):
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
@@ -692,6 +695,32 @@ class Attention(nn.Module):
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
self.fused_projections = fuse
class AttnProcessor:
r"""
@@ -1184,9 +1213,6 @@ class AttnProcessor2_0:
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1253,6 +1279,103 @@ class AttnProcessor2_0:
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ CROSS_ATTENTION_PROCESSORS = (
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,

View File

@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

View File

@@ -25,6 +25,7 @@ from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -34,6 +34,7 @@ from ...loaders import (
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -681,7 +682,6 @@ class StableDiffusionXLPipeline(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ class StableDiffusionXLPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ class StableDiffusionXLPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""

View File

@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, Te
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need

View File

@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
from ...models.attention import Attention
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -1000,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return max_sigma
return (self.sigmas.max() ** 2 + 1) ** 0.5
return (max_sigma**2 + 1) ** 0.5
@property
def step_index(self):
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if sigmas.device.type == "cuda":
self.sigmas = self.sigmas.tolist()
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):

View File

@@ -938,6 +938,37 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_with_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
sd_pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
sd_pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):