Compare commits

..

4 Commits

Author SHA1 Message Date
Dhruv Nair
858dfd6411 update 2023-12-06 12:14:35 +00:00
Dhruv Nair
6cb2178a91 Revert "fix"
This reverts commit f90a5139a2.
2023-12-06 06:44:02 +00:00
Dhruv Nair
f90a5139a2 fix 2023-12-06 06:03:58 +00:00
Sayak Paul
a2bc2e14b9 [feat] allow SDXL pipeline to run with fused QKV projections (#6030)
* debug

* from step

* print

* turn sigma a list

* make str

* init_noise_sigma

* comment

* remove prints

* feat: introduce fused projections

* change to a better name

* no grad

* device.

* device

* dtype

* okay

* print

* more print

* fix: unbind -> split

* fix: qkv >-> k

* enable disable

* apply attention processor within the method

* attn processors

* _enable_fused_qkv_projections

* remove print

* add fused projection to vae

* add todos.

* add: documentation and cleanups.

* add: test for qkv projection fusion.

* relax assertions.

* relax further

* fix: docs

* fix-copies

* correct error message.

* Empty-Commit

* better conditioning on disable_fused_qkv_projections

* check

* check processor

* bfloat16 computation.

* check latent dtype

* style

* remove copy temporarily

* cast latent to bfloat16

* fix: vae -> self.vae

* remove print.

* add _change_to_group_norm_32

* comment out stuff that didn't work

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* reflect patrick's suggestions.

* fix imports

* fix: disable call.

* fix more

* fix device and dtype

* fix conditions.

* fix more

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-06 07:33:26 +05:30
45 changed files with 589 additions and 250 deletions

View File

@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mech
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0
## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

View File

@@ -174,4 +174,10 @@ Set `private=True` in the [`~diffusers.utils.PushToHubMixin.push_to_hub`] functi
controlnet.push_to_hub("my-controlnet-model-private", private=True)
```
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for`. You must be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) to load a model from a private repository.
Private repositories are only visible to you, and other users won't be able to clone the repository and your repository won't appear in search results. Even if a user has the URL to your private repository, they'll receive a `404 - Sorry, we can't find the page you are looking for.`
To load a model, scheduler, or pipeline from private or gated repositories, set `use_auth_token=True`:
```py
model = ControlNetModel.from_pretrained("your-namespace/my-controlnet-model-private", use_auth_token=True)
```

View File

@@ -512,6 +512,7 @@ device = torch.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
safety_checker=None,
use_auth_token=True,
custom_pipeline="imagic_stable_diffusion",
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
).to(device)
@@ -551,6 +552,7 @@ device = th.device('cpu' if not has_cuda else 'cuda')
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="seed_resize_stable_diffusion"
).to(device)
@@ -586,6 +588,7 @@ generator = th.Generator("cuda").manual_seed(0)
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device)
@@ -604,6 +607,7 @@ image.save('./seed_resize/seed_resize_{w}_{h}_image.png'.format(w=width, h=heigh
pipe_compare = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
use_auth_token=True,
custom_pipeline="/home/mark/open_source/diffusers/examples/community/"
).to(device)

View File

@@ -5,11 +5,10 @@ from typing import Dict, List, Union
import safetensors.torch
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from diffusers import DiffusionPipeline, __version__
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
from diffusers.utils import CONFIG_NAME, DIFFUSERS_CACHE, ONNX_WEIGHTS_NAME, WEIGHTS_NAME
class CheckpointMergerPipeline(DiffusionPipeline):
@@ -58,7 +57,6 @@ class CheckpointMergerPipeline(DiffusionPipeline):
return (temp_dict, meta_keys)
@torch.no_grad()
@validate_hf_hub_args
def merge(self, pretrained_model_name_or_path_list: List[Union[str, os.PathLike]], **kwargs):
"""
Returns a new pipeline object of the class 'DiffusionPipeline' with the merged checkpoints(weights) of the models passed
@@ -71,7 +69,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
**kwargs:
Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
cache_dir, resume_download, force_download, proxies, local_files_only, use_auth_token, revision, torch_dtype, device_map.
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
@@ -83,12 +81,12 @@ class CheckpointMergerPipeline(DiffusionPipeline):
"""
# Default kwargs from DiffusionPipeline
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
device_map = kwargs.pop("device_map", None)
@@ -125,7 +123,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
config_dicts.append(config_dict)
@@ -161,7 +159,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,

View File

@@ -28,7 +28,6 @@ import PIL.Image
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -51,7 +50,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -779,13 +778,12 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -797,7 +795,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)

View File

@@ -28,7 +28,6 @@ import PIL.Image
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -52,7 +51,7 @@ from diffusers.pipelines.stable_diffusion import (
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -780,13 +779,12 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -798,7 +796,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)

View File

@@ -27,7 +27,6 @@ import onnx_graphsurgeon as gs
import tensorrt as trt
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from onnx import shape_inference
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
@@ -50,7 +49,7 @@ from diffusers.pipelines.stable_diffusion import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import DDIMScheduler
from diffusers.utils import logging
from diffusers.utils import DIFFUSERS_CACHE, logging
"""
@@ -692,13 +691,12 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
self.models["vae"] = make_VAE(self.vae, **models_args)
@classmethod
@validate_hf_hub_args
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
cls.cached_folder = (
@@ -710,7 +708,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
)

View File

@@ -423,7 +423,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]

View File

@@ -392,7 +392,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]

View File

@@ -400,7 +400,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]

View File

@@ -414,7 +414,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]

View File

@@ -420,7 +420,7 @@ def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
)
model_class = text_encoder_config.architectures[0]
@@ -975,7 +975,7 @@ def main(args):
revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True
)
if args.controlnet_model_name_or_path:

View File

@@ -19,7 +19,6 @@ Usage example:
import glob
import json
import warnings
from argparse import ArgumentParser, Namespace
from importlib import import_module
@@ -33,12 +32,12 @@ from . import BaseDiffusersCLICommand
def conversion_command_factory(args: Namespace):
if args.use_auth_token:
warnings.warn(
"The `--use_auth_token` flag is deprecated and will be removed in a future version. Authentication is now"
" handled automatically if user is logged in."
)
return FP16SafetensorsCommand(args.ckpt_id, args.fp16, args.use_safetensors)
return FP16SafetensorsCommand(
args.ckpt_id,
args.fp16,
args.use_safetensors,
args.use_auth_token,
)
class FP16SafetensorsCommand(BaseDiffusersCLICommand):
@@ -63,7 +62,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
)
conversion_parser.set_defaults(func=conversion_command_factory)
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool):
def __init__(self, ckpt_id: str, fp16: bool, use_safetensors: bool, use_auth_token: bool):
self.logger = logging.get_logger("diffusers-cli/fp16_safetensors")
self.ckpt_id = ckpt_id
self.local_ckpt_dir = f"/tmp/{ckpt_id}"
@@ -76,6 +75,8 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
"When `use_safetensors` and `fp16` both are False, then this command is of no use."
)
self.use_auth_token = use_auth_token
def run(self):
if version.parse(huggingface_hub.__version__) < version.parse("0.9.0"):
raise ImportError(
@@ -86,7 +87,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
from huggingface_hub import create_commit
from huggingface_hub._commit_api import CommitOperationAdd
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json")
model_index = hf_hub_download(repo_id=self.ckpt_id, filename="model_index.json", token=self.use_auth_token)
with open(model_index, "r") as f:
pipeline_class_name = json.load(f)["_class_name"]
pipeline_class = getattr(import_module("diffusers"), pipeline_class_name)
@@ -95,7 +96,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
# Load the appropriate pipeline. We could have use `DiffusionPipeline`
# here, but just to avoid any rough edge cases.
pipeline = pipeline_class.from_pretrained(
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32
self.ckpt_id, torch_dtype=torch.float16 if self.fp16 else torch.float32, use_auth_token=self.use_auth_token
)
pipeline.save_pretrained(
self.local_ckpt_dir,

View File

@@ -27,16 +27,12 @@ from typing import Any, Dict, Tuple, Union
import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from . import __version__
from .utils import (
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
DummyObject,
deprecate,
@@ -279,7 +275,6 @@ class ConfigMixin:
return cls.load_config(*args, **kwargs)
@classmethod
@validate_hf_hub_args
def load_config(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -316,7 +311,7 @@ class ConfigMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -334,11 +329,11 @@ class ConfigMixin:
A dictionary of all the parameters stored in a JSON configuration file.
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
_ = kwargs.pop("mirror", None)
@@ -381,7 +376,7 @@ class ConfigMixin:
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
@@ -390,7 +385,8 @@ class ConfigMixin:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
" token having permission to this repo with `token` or log in with `huggingface-cli login`."
" token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
" login`."
)
except RevisionNotFoundError:
raise EnvironmentError(

View File

@@ -15,10 +15,11 @@ import os
from typing import Dict, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_transformers_available,
logging,
@@ -42,7 +43,6 @@ logger = logging.get_logger(__name__)
class IPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -77,7 +77,7 @@ class IPAdapterMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -88,12 +88,12 @@ class IPAdapterMixin:
"""
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
user_agent = {
@@ -110,7 +110,7 @@ class IPAdapterMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,

View File

@@ -18,13 +18,14 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
@@ -131,7 +132,6 @@ class LoraLoaderMixin:
)
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -174,7 +174,7 @@ class LoraLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -195,12 +195,12 @@ class LoraLoaderMixin:
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -239,7 +239,7 @@ class LoraLoaderMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -265,7 +265,7 @@ class LoraLoaderMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,

View File

@@ -18,9 +18,10 @@ from pathlib import Path
import requests
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
deprecate,
is_accelerate_available,
is_omegaconf_available,
@@ -51,7 +52,6 @@ class FromSingleFileMixin:
return cls.from_single_file(*args, **kwargs)
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
@@ -81,7 +81,7 @@ class FromSingleFileMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -154,12 +154,12 @@ class FromSingleFileMixin:
original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
extract_ema = kwargs.pop("extract_ema", False)
image_size = kwargs.pop("image_size", None)
@@ -253,7 +253,7 @@ class FromSingleFileMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
@@ -293,7 +293,6 @@ class FromOriginalVAEMixin:
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -323,7 +322,7 @@ class FromOriginalVAEMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -380,12 +379,12 @@ class FromOriginalVAEMixin:
)
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
image_size = kwargs.pop("image_size", None)
scaling_factor = kwargs.pop("scaling_factor", None)
@@ -426,7 +425,7 @@ class FromOriginalVAEMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)
@@ -491,7 +490,6 @@ class FromOriginalControlnetMixin:
"""
@classmethod
@validate_hf_hub_args
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
r"""
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
@@ -521,7 +519,7 @@ class FromOriginalControlnetMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to True, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -557,12 +555,12 @@ class FromOriginalControlnetMixin:
from ..pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
config_file = kwargs.pop("config_file", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
num_in_channels = kwargs.pop("num_in_channels", None)
use_linear_projection = kwargs.pop("use_linear_projection", None)
revision = kwargs.pop("revision", None)
@@ -605,7 +603,7 @@ class FromOriginalControlnetMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
force_download=force_download,
)

View File

@@ -15,10 +15,16 @@ from typing import Dict, List, Optional, Union
import safetensors
import torch
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_accelerate_available,
is_transformers_available,
logging,
)
if is_transformers_available():
@@ -33,14 +39,13 @@ TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@validate_hf_hub_args
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -74,7 +79,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -95,7 +100,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -262,7 +267,6 @@ class TextualInversionLoaderMixin:
return all_tokens, all_embeddings
@validate_hf_hub_args
def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -316,7 +320,7 @@ class TextualInversionLoaderMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):

View File

@@ -19,12 +19,13 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn
from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND,
_get_model_file,
delete_adapter_layers,
@@ -61,7 +62,6 @@ class UNet2DConditionLoadersMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
@validate_hf_hub_args
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
@@ -95,7 +95,7 @@ class UNet2DConditionLoadersMixin:
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
@@ -130,12 +130,12 @@ class UNet2DConditionLoadersMixin:
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
@@ -184,7 +184,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -204,7 +204,7 @@ class UNet2DConditionLoadersMixin:
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,

View File

@@ -33,8 +33,8 @@ if is_torch_available():
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]

View File

@@ -113,12 +113,14 @@ class Attention(nn.Module):
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ class Attention(nn.Module):
else:
linear_cls = LoRACompatibleLinear
self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
@@ -692,6 +695,32 @@ class Attention(nn.Module):
return encoder_hidden_states
@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)
self.fused_projections = fuse
class AttnProcessor:
r"""
@@ -1184,9 +1213,6 @@ class AttnProcessor2_0:
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1253,6 +1279,103 @@ class AttnProcessor2_0:
return hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is currently 🧪 experimental in nature and can change in future.
</Tip>
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)
kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ CROSS_ATTENTION_PROCESSORS = (
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,

View File

@@ -22,6 +22,7 @@ from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -448,3 +449,41 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

View File

@@ -24,17 +24,13 @@ from flax.core.frozen_dict import FrozenDict, unfreeze
from flax.serialization import from_bytes, to_bytes
from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
WEIGHTS_NAME,
@@ -201,7 +197,6 @@ class FlaxModelMixin(PushToHubMixin):
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
@@ -293,13 +288,13 @@ class FlaxModelMixin(PushToHubMixin):
```
"""
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
from_pt = kwargs.pop("from_pt", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
@@ -319,7 +314,7 @@ class FlaxModelMixin(PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
**kwargs,
@@ -364,7 +359,7 @@ class FlaxModelMixin(PushToHubMixin):
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
@@ -374,7 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `token` or log in with `huggingface-cli "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:

View File

@@ -25,13 +25,14 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import safetensors
import torch
from huggingface_hub import create_repo
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
from .. import __version__
from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
@@ -534,7 +535,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained PyTorch model from a pretrained model configuration.
@@ -571,7 +571,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -640,15 +640,15 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
@@ -718,7 +718,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
@@ -740,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -763,7 +763,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
@@ -782,7 +782,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,

View File

@@ -25,6 +25,7 @@ from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -794,6 +795,42 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -16,9 +16,8 @@
import inspect
from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import DIFFUSERS_CACHE
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
@@ -196,7 +195,6 @@ class AutoPipelineForText2Image(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a text-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -248,7 +246,7 @@ class AutoPipelineForText2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -312,11 +310,11 @@ class AutoPipelineForText2Image(ConfigMixin):
>>> image = pipeline(prompt).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -325,7 +323,7 @@ class AutoPipelineForText2Image(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}
@@ -468,7 +466,6 @@ class AutoPipelineForImage2Image(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a image-to-image Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -521,7 +518,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -585,11 +582,11 @@ class AutoPipelineForImage2Image(ConfigMixin):
>>> image = pipeline(prompt, image).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -598,7 +595,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}
@@ -745,7 +742,6 @@ class AutoPipelineForInpainting(ConfigMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path, **kwargs):
r"""
Instantiates a inpainting Pytorch diffusion pipeline from pretrained pipeline weight.
@@ -797,7 +793,7 @@ class AutoPipelineForInpainting(ConfigMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -861,11 +857,11 @@ class AutoPipelineForInpainting(ConfigMixin):
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
@@ -874,7 +870,7 @@ class AutoPipelineForInpainting(ConfigMixin):
"force_download": force_download,
"resume_download": resume_download,
"proxies": proxies,
"token": token,
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
}

View File

@@ -22,7 +22,6 @@ from typing import Optional, Union
import numpy as np
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, is_onnx_available, logging
@@ -131,11 +130,10 @@ class OnnxRuntimeModel:
self._save_pretrained(save_directory, **kwargs)
@classmethod
@validate_hf_hub_args
def _from_pretrained(
cls,
model_id: Union[str, Path],
token: Optional[Union[bool, str, None]] = None,
use_auth_token: Optional[Union[bool, str, None]] = None,
revision: Optional[Union[str, None]] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
@@ -150,7 +148,7 @@ class OnnxRuntimeModel:
Arguments:
model_id (`str` or `Path`):
Directory from which to load
token (`str` or `bool`):
use_auth_token (`str` or `bool`):
Is needed to load models from a private or gated repository
revision (`str`):
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
@@ -181,7 +179,7 @@ class OnnxRuntimeModel:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=model_file_name,
token=token,
use_auth_token=use_auth_token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
@@ -192,12 +190,11 @@ class OnnxRuntimeModel:
return cls(model=model, **kwargs)
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
model_id: Union[str, Path],
force_download: bool = True,
token: Optional[str] = None,
use_auth_token: Optional[str] = None,
cache_dir: Optional[str] = None,
**model_kwargs,
):
@@ -210,6 +207,6 @@ class OnnxRuntimeModel:
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
token=token,
use_auth_token=use_auth_token,
**model_kwargs,
)

View File

@@ -24,7 +24,6 @@ import numpy as np
import PIL.Image
from flax.core.frozen_dict import FrozenDict
from huggingface_hub import create_repo, snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
from PIL import Image
from tqdm.auto import tqdm
@@ -33,6 +32,7 @@ from ..models.modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from ..schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from ..utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
BaseOutput,
PushToHubMixin,
http_user_agent,
@@ -227,7 +227,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
)
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a Flax-based diffusion pipeline from pretrained pipeline weights.
@@ -265,7 +264,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -315,11 +314,11 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> dpm_params["scheduler"] = dpmpp_state
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
token = kwargs.pop("token", None)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_pt = kwargs.pop("from_pt", False)
use_memory_efficient_attention = kwargs.pop("use_memory_efficient_attention", False)
@@ -335,7 +334,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
@@ -366,7 +365,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,

View File

@@ -28,14 +28,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from huggingface_hub import (
ModelCard,
create_repo,
hf_hub_download,
model_info,
snapshot_download,
)
from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub import ModelCard, create_repo, hf_hub_download, model_info, snapshot_download
from packaging import version
from requests.exceptions import HTTPError
from tqdm.auto import tqdm
@@ -47,6 +40,8 @@ from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
@@ -254,11 +249,10 @@ def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLi
return usable_filenames, variant_filenames
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, variant, revision, model_filenames):
info = model_info(
pretrained_model_name_or_path,
token=token,
use_auth_token=use_auth_token,
revision=None,
)
filenames = {sibling.rfilename for sibling in info.siblings}
@@ -381,6 +375,7 @@ def _get_pipeline_class(
custom_pipeline,
module_file=file_name,
class_name=class_name,
repo_id=repo_id,
cache_dir=cache_dir,
revision=revision,
)
@@ -914,7 +909,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return torch.float32
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pretrained pipeline weights.
@@ -982,7 +976,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -1062,12 +1056,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
>>> pipeline.scheduler = scheduler
```
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
torch_dtype = kwargs.pop("torch_dtype", None)
@@ -1100,7 +1094,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
from_flax=from_flax,
use_safetensors=use_safetensors,
@@ -1305,7 +1299,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"use_auth_token": use_auth_token,
"revision": revision,
"torch_dtype": torch_dtype,
"custom_pipeline": custom_pipeline,
@@ -1535,7 +1529,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
cpu_offload(model, device, offload_buffers=offload_buffers)
@classmethod
@validate_hf_hub_args
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
r"""
Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights.
@@ -1583,7 +1576,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
@@ -1626,12 +1619,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
</Tip>
"""
cache_dir = kwargs.pop("cache_dir", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
custom_pipeline = kwargs.pop("custom_pipeline", None)
@@ -1653,7 +1646,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model_info_call_error: Optional[Exception] = None
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
info = model_info(
pretrained_model_name,
use_auth_token=use_auth_token,
revision=revision,
)
except HTTPError as e:
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
@@ -1668,7 +1665,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
proxies=proxies,
force_download=force_download,
resume_download=resume_download,
token=token,
use_auth_token=use_auth_token,
)
config_dict = cls._dict_from_json_file(config_file)
@@ -1718,7 +1715,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
warn_deprecated_model_variant(
pretrained_model_name, use_auth_token, variant, revision, model_filenames
)
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
@@ -1860,7 +1859,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
@@ -1884,7 +1883,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"use_auth_token": use_auth_token,
"variant": variant,
"use_safetensors": use_safetensors,
}

View File

@@ -34,6 +34,7 @@ from ...loaders import (
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -681,7 +682,6 @@ class StableDiffusionXLPipeline(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ class StableDiffusionXLPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ class StableDiffusionXLPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False
if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())
if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False
if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""

View File

@@ -24,6 +24,7 @@ from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, Te
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -610,6 +611,7 @@ class StableDiffusionXLInstructPix2PixPipeline(
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need

View File

@@ -10,10 +10,10 @@ from diffusers.utils import deprecate
from ...configuration_utils import ConfigMixin, register_to_config
from ...models import ModelMixin
from ...models.activations import get_activation
from ...models.attention import Attention
from ...models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
@@ -1000,6 +1000,42 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
sample: torch.FloatTensor,

View File

@@ -191,10 +191,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
max_sigma = max(self.sigmas) if isinstance(self.sigmas, list) else self.sigmas.max()
if self.config.timestep_spacing in ["linspace", "trailing"]:
return self.sigmas.max()
return max_sigma
return (self.sigmas.max() ** 2 + 1) ** 0.5
return (max_sigma**2 + 1) ** 0.5
@property
def step_index(self):
@@ -289,6 +290,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if sigmas.device.type == "cuda":
self.sigmas = self.sigmas.tolist()
self._step_index = None
def _sigma_to_t(self, sigma, log_sigmas):

View File

@@ -18,7 +18,6 @@ from enum import Enum
from typing import Optional, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin
@@ -82,7 +81,6 @@ class SchedulerMixin(PushToHubMixin):
has_compatibles = True
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
@@ -122,7 +120,7 @@ class SchedulerMixin(PushToHubMixin):
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):

View File

@@ -20,7 +20,6 @@ from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from huggingface_hub.utils import validate_hf_hub_args
from ..utils import BaseOutput, PushToHubMixin
@@ -71,7 +70,6 @@ class FlaxSchedulerMixin(PushToHubMixin):
has_compatibles = True
@classmethod
@validate_hf_hub_args
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
@@ -112,7 +110,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):

View File

@@ -21,6 +21,7 @@ from .. import __version__
from .constants import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME,
HF_MODULES_CACHE,
@@ -37,6 +38,7 @@ from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
from .hub_utils import (
HF_HUB_OFFLINE,
PushToHubMixin,
_add_variant,
_get_model_file,

View File

@@ -14,13 +14,15 @@
import importlib
import os
from huggingface_hub.constants import HF_HOME
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
from ..dependency_versions_check import dep_version_check
from .import_utils import ENV_VARS_TRUE_VALUES, is_peft_available, is_transformers_available
default_cache_path = HUGGINGFACE_HUB_CACHE
MIN_PEFT_VERSION = "0.6.0"
MIN_TRANSFORMERS_VERSION = "4.34.0"
_CHECK_PEFT = os.environ.get("_CHECK_PEFT", "1") in ENV_VARS_TRUE_VALUES
@@ -33,8 +35,9 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
# Below should be `True` if the current version of `peft` and `transformers` are compatible with

View File

@@ -25,8 +25,7 @@ from pathlib import Path
from typing import Dict, Optional, Union
from urllib import request
from huggingface_hub import cached_download, hf_hub_download, model_info
from huggingface_hub.utils import validate_hf_hub_args
from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info
from packaging import version
from .. import __version__
@@ -195,7 +194,6 @@ def find_pipeline_class(loaded_module):
return pipeline_class
@validate_hf_hub_args
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
@@ -203,7 +201,7 @@ def get_cached_module_file(
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
):
@@ -234,7 +232,7 @@ def get_cached_module_file(
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or *bool*, *optional*):
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -246,7 +244,7 @@ def get_cached_module_file(
<Tip>
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
@@ -291,7 +289,7 @@ def get_cached_module_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=False,
use_auth_token=False,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
@@ -309,7 +307,7 @@ def get_cached_module_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
except EnvironmentError:
@@ -334,6 +332,13 @@ def get_cached_module_file(
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
@@ -354,14 +359,13 @@ def get_cached_module_file(
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)
return os.path.join(full_submodule, module_file)
@validate_hf_hub_args
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
@@ -370,7 +374,7 @@ def get_class_from_dynamic_module(
force_download: bool = False,
resume_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
**kwargs,
@@ -410,7 +414,7 @@ def get_class_from_dynamic_module(
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
token (`str` or `bool`, *optional*):
use_auth_token (`str` or `bool`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
@@ -422,7 +426,7 @@ def get_class_from_dynamic_module(
<Tip>
You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
You may pass a token in `use_auth_token` if you are not logged in (`huggingface-cli long`) and want to use private
or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
@@ -445,7 +449,7 @@ def get_class_from_dynamic_module(
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
token=token,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
)

View File

@@ -25,21 +25,20 @@ from typing import Dict, Optional, Union
from uuid import uuid4
from huggingface_hub import (
HfFolder,
ModelCard,
ModelCardData,
create_repo,
get_full_repo_name,
hf_hub_download,
upload_folder,
whoami,
)
from huggingface_hub.constants import HF_HUB_CACHE, HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
validate_hf_hub_args,
)
from packaging import version
from requests import HTTPError
@@ -47,6 +46,7 @@ from requests import HTTPError
from .. import __version__
from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
@@ -69,6 +69,9 @@ logger = get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex
HF_HUB_OFFLINE = os.getenv("HF_HUB_OFFLINE", "").upper() in ENV_VARS_TRUE_VALUES
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
HUGGINGFACE_CO_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/"
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
@@ -76,7 +79,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
Formats a user-agent string with basic info about a request.
"""
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
if HF_HUB_DISABLE_TELEMETRY or HF_HUB_OFFLINE:
if DISABLE_TELEMETRY or HF_HUB_OFFLINE:
return ua + "; telemetry/off"
if is_torch_available():
ua += f"; torch/{_torch_version}"
@@ -95,6 +98,16 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def create_model_card(args, model_name):
if not is_jinja_available():
raise ValueError(
@@ -170,7 +183,7 @@ old_diffusers_cache = os.path.join(hf_cache_home, "diffusers")
def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] = None) -> None:
if new_cache_dir is None:
new_cache_dir = HF_HUB_CACHE
new_cache_dir = DIFFUSERS_CACHE
if old_cache_dir is None:
old_cache_dir = old_diffusers_cache
@@ -190,7 +203,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
# At this point, old_cache_dir contains symlinks to the new cache (it can still be used).
cache_version_file = os.path.join(HF_HUB_CACHE, "version_diffusers_cache.txt")
cache_version_file = os.path.join(DIFFUSERS_CACHE, "version_diffusers_cache.txt")
if not os.path.isfile(cache_version_file):
cache_version = 0
else:
@@ -220,12 +233,12 @@ if cache_version < 1:
if cache_version < 1:
try:
os.makedirs(HF_HUB_CACHE, exist_ok=True)
os.makedirs(DIFFUSERS_CACHE, exist_ok=True)
with open(cache_version_file, "w") as f:
f.write("1")
except Exception:
logger.warning(
f"There was a problem when trying to write in your cache folder ({HF_HUB_CACHE}). Please, ensure "
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
"the directory exists and can be written to."
)
@@ -239,21 +252,20 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name
@validate_hf_hub_args
def _get_model_file(
pretrained_model_name_or_path: Union[str, Path],
pretrained_model_name_or_path,
*,
weights_name: str,
subfolder: Optional[str],
cache_dir: Optional[str],
force_download: bool,
proxies: Optional[Dict],
resume_download: bool,
local_files_only: bool,
token: Optional[str],
user_agent: Union[Dict, str, None],
revision: Optional[str],
commit_hash: Optional[str] = None,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
@@ -288,7 +300,7 @@ def _get_model_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
@@ -313,7 +325,7 @@ def _get_model_file(
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
@@ -324,7 +336,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `token` or log in with `huggingface-cli "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:

View File

View File

@@ -164,7 +164,7 @@ class PriorTransformerIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
@parameterized.expand(
[

View File

@@ -869,7 +869,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32

View File

@@ -485,7 +485,7 @@ class AutoencoderTinyIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -565,7 +565,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
@@ -820,7 +820,7 @@ class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
dtype = torch.float16 if fp16 else torch.float32

View File

@@ -310,7 +310,7 @@ class StableDiffusion2PipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"
@@ -531,7 +531,7 @@ class StableDiffusion2PipelineNightlyTests(unittest.TestCase):
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache()
backend_empty_cache(torch_device)
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
_generator_device = "cpu" if not generator_device.startswith("cuda") else "cuda"

View File

@@ -938,6 +938,37 @@ class StableDiffusionXLPipelineFastTests(
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_stable_diffusion_xl_with_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
original_image_slice = image[0, -3:, -3:, -1]
sd_pipe.fuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_fused = image[0, -3:, -3:, -1]
sd_pipe.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-2, rtol=1e-2
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@slow
class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):