Compare commits

..

2 Commits

Author SHA1 Message Date
Patrick von Platen
95de8000ec Merge branch 'main' of https://github.com/huggingface/diffusers into post_release_0260 2024-02-09 16:13:18 +00:00
Patrick von Platen
2dfc2e8c47 post release 2024-02-01 00:37:04 +02:00
34 changed files with 2806 additions and 612 deletions

View File

@@ -34,6 +34,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- name: LoRA
framework: lora
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_lora
- name: Fast Flax CPU tests
framework: flax
runner: docker-cpu
@@ -89,6 +94,14 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- name: Run fast PyTorch LoRA CPU tests
if: ${{ matrix.config.framework == 'lora' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |

View File

@@ -26,9 +26,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
python3.9 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
invisible_watermark && \
python3.9 -m pip install --no-cache-dir \
accelerate \

View File

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
invisible_watermark \
--extra-index-url https://download.pytorch.org/whl/cpu && \
python3 -m pip install --no-cache-dir \

View File

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
accelerate \

View File

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
accelerate \

View File

@@ -56,60 +56,6 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
### Load from a local file
Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a `pipeline.py` file that contains the pipeline class in order to successfully load it.
```py
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
custom_pipeline="./path/to/pipeline_directory/",
clip_model=clip_model,
feature_extractor=feature_extractor,
use_safetensors=True,
)
```
### Load from a specific version
By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.
<hfoptions id="version">
<hfoption id="main">
For example, to load from the `main` branch:
```py
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
custom_pipeline="clip_guided_stable_diffusion",
custom_revision="main",
clip_model=clip_model,
feature_extractor=feature_extractor,
use_safetensors=True,
)
```
</hfoption>
<hfoption id="older version">
For example, to load from a previous version of Diffusers like `v0.25.0`:
```py
pipeline = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
custom_pipeline="clip_guided_stable_diffusion",
custom_revision="v0.25.0",
clip_model=clip_model,
feature_extractor=feature_extractor,
use_safetensors=True,
)
```
</hfoption>
</hfoptions>
For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
## Community components

View File

@@ -42,7 +42,6 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.27.0.dev0")
class MarigoldDepthOutput(BaseOutput):
"""
Output class for Marigold monocular depth prediction pipeline.

View File

@@ -376,14 +376,18 @@ After training, LoRA weights can be loaded very easily into the original pipelin
load the original pipeline:
```python
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("base-model-name").to("cuda")
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```
Next, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).
Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
```python
pipe.load_lora_weights("path-to-the-lora-checkpoint")
pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
```
Finally, we can run the model in inference.

View File

@@ -49,7 +49,6 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -196,7 +195,7 @@ def import_model_class_from_model_name_or_path(
raise ValueError(f"{model_class} is not supported.")
def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
@@ -210,25 +209,27 @@ def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = Non
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- t2iadapter
inference: true
---
"""
model_card = f"""
# t2iadapter-{repo_id}
These are t2iadapter weights trained on {base_model} with new type of conditioning.
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def parse_args(input_args=None):

View File

@@ -45,7 +45,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -67,8 +66,8 @@ DATASET_NAME_MAPPING = {
def save_model_card(
args,
repo_id: str,
images: list = None,
repo_folder: str = None,
images=None,
repo_folder=None,
):
img_str = ""
if len(images) > 0:
@@ -76,7 +75,21 @@ def save_model_card(
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card = f"""
# Text-to-image finetuning - {repo_id}
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
@@ -119,21 +132,10 @@ These are the key hyperparameters used during training:
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""
model_description += wandb_info
model_card += wandb_info
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=args.pretrained_model_name_or_path,
model_description=model_description,
inference=True,
)
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):

View File

@@ -45,7 +45,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDif
from diffusers.optimization import get_scheduler
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -56,39 +55,32 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__, log_level="INFO")
def save_model_card(
repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None
):
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"stable-diffusion",
"stable-diffusion-diffusers",
"text-to-image",
"diffusers",
"lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def parse_args():

View File

@@ -58,7 +58,6 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -71,20 +70,33 @@ logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images: list = None,
base_model: str = None,
dataset_name: str = None,
train_text_encoder: bool = False,
repo_folder: str = None,
vae_path: str = None,
images=None,
base_model=str,
dataset_name=str,
train_text_encoder=False,
repo_folder=None,
vae_path=None,
):
img_str = ""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
dataset: {dataset_name}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
@@ -94,19 +106,8 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def import_model_class_from_model_name_or_path(

View File

@@ -48,7 +48,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, U
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -66,45 +65,41 @@ DATASET_NAME_MAPPING = {
def save_model_card(
repo_id: str,
images: list = None,
validation_prompt: str = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
vae_path: str = None,
images=None,
validation_prompt=None,
base_model=str,
dataset_name=str,
repo_folder=None,
vae_path=None,
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
dataset: {dataset_name}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
inference: true
---
"""
model_card = f"""
# Text-to-image finetuning - {repo_id}
This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
{img_str}
Special VAE used for training: {vae_path}.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
def import_model_class_from_model_name_or_path(

View File

@@ -167,10 +167,7 @@ vae_conversion_map_attn = [
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
if not w.ndim == 1:
return w.reshape(*w.shape, 1, 1)
else:
return w
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):
@@ -324,18 +321,11 @@ if __name__ == "__main__":
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Convert text encoder 1
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
# Convert text encoder 2
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
# We call the `.T.contiguous()` to match what's done in
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
"conditioner.embedders.1.model.text_projection.weight"
).T.contiguous()
# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}

View File

@@ -170,10 +170,7 @@ vae_extra_conversion_map = [
def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
if not w.ndim == 1:
return w.reshape(*w.shape, 1, 1)
else:
return w
return w.reshape(*w.shape, 1, 1)
def convert_vae_state_dict(vae_state_dict):

View File

@@ -126,8 +126,8 @@ _deps = [
"regex!=2019.12.17",
"requests",
"tensorboard",
"torch>=1.4",
"torchvision",
"torch>=1.4,<2.2.0",
"torchvision<0.17",
"transformers>=4.25.1",
"urllib3<=2.0.0",
]

View File

@@ -38,8 +38,8 @@ deps = {
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"torch": "torch>=1.4,<2.2.0",
"torchvision": "torchvision<0.17",
"transformers": "transformers>=4.25.1",
"urllib3": "urllib3<=2.0.0",
}

View File

@@ -166,7 +166,8 @@ class IPAdapterMixin:
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
from contextlib import nullcontext
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -25,7 +26,7 @@ from packaging import version
from torch import nn
from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
@@ -33,6 +34,7 @@ from ..utils import (
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
@@ -49,9 +51,10 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
if is_transformers_available():
from transformers import PreTrainedModel
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
@@ -103,9 +106,6 @@ class LoraLoaderMixin:
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -397,17 +397,16 @@ class LoraLoaderMixin:
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith("unet.unet") for key in keys):
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
deprecate("unet.unet keys", "0.27", deprecation_message)
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
@@ -428,7 +427,9 @@ class LoraLoaderMixin:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
if len(state_dict.keys()) > 0:
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
if adapter_name in getattr(unet, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
@@ -517,11 +518,6 @@ class LoraLoaderMixin:
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -543,21 +539,34 @@ class LoraLoaderMixin:
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
if USE_PEFT_BACKEND:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
if network_alphas is not None:
alpha_keys = [
@@ -567,25 +576,84 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config = LoraConfig(**lora_config_kwargs)
if USE_PEFT_BACKEND:
from peft import LoraConfig
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
else:
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(
text_encoder_lora_state_dict, strict=False
)
unexpected_keys = load_state_dict_results.unexpected_keys
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
@@ -621,8 +689,6 @@ class LoraLoaderMixin:
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
keys = list(state_dict.keys())
@@ -639,6 +705,8 @@ class LoraLoaderMixin:
}
if len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
@@ -686,20 +754,118 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
remove_method = recurse_remove_peft_layers
if USE_PEFT_BACKEND:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
# In case text encoder have no Lora attached
if getattr(self.text_encoder, "peft_config", None) is not None:
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
if getattr(self.text_encoder_2, "peft_config", None) is not None:
if USE_PEFT_BACKEND:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_linear_layer = None
attn_module.k_proj.lora_linear_layer = None
attn_module.v_proj.lora_linear_layer = None
attn_module.out_proj.lora_linear_layer = None
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_linear_layer = None
mlp_module.fc2.lora_linear_layer = None
@classmethod
def _modify_text_encoder(
cls,
text_encoder,
lora_scale=1,
network_alphas=None,
rank: Union[Dict[str, int], int] = 4,
dtype=None,
patch_mlp=False,
low_cpu_mem_usage=False,
):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
lora_parameters.extend(model.lora_linear_layer.parameters())
return model
# First, remove any monkey-patch that might have been applied before
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
lora_parameters = []
network_alphas = {} if network_alphas is None else network_alphas
is_network_alphas_populated = len(network_alphas) > 0
for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
if isinstance(rank, dict):
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
else:
current_rank = rank
attn_module.q_proj = create_patched_linear_lora(
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
)
attn_module.k_proj = create_patched_linear_lora(
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
)
attn_module.v_proj = create_patched_linear_lora(
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
)
attn_module.out_proj = create_patched_linear_lora(
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
)
if patch_mlp:
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
mlp_module.fc1 = create_patched_linear_lora(
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
)
mlp_module.fc2 = create_patched_linear_lora(
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
)
if is_network_alphas_populated and len(network_alphas) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
return lora_parameters
@classmethod
def save_lora_weights(
cls,
@@ -873,8 +1039,6 @@ class LoraLoaderMixin:
pipeline.fuse_lora(lora_scale=0.7)
```
"""
from peft.tuners.tuners_utils import BaseTunerLayer
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
if self.num_fused_loras > 1:
@@ -886,26 +1050,52 @@ class LoraLoaderMixin:
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
merge_kwargs = {"safe_merge": safe_fusing}
if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
merge_kwargs = {"safe_merge": safe_fusing}
# For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. "
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
)
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
module.merge(**merge_kwargs)
# For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. "
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
else:
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
raise ValueError(
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
"backend to use this argument by installing latest PEFT and transformers."
" `pip install -U peft transformers`"
)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -930,18 +1120,40 @@ class LoraLoaderMixin:
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
"""
from peft.tuners.tuners_utils import BaseTunerLayer
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if unfuse_unet:
for module in unet.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
if not USE_PEFT_BACKEND:
unet.unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
for module in unet.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
@@ -1222,9 +1434,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
@@ -1329,13 +1538,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
if getattr(self.text_encoder, "peft_config", None) is not None:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if USE_PEFT_BACKEND:
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
if getattr(self.text_encoder, "peft_config", None) is not None:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
if getattr(self.text_encoder_2, "peft_config", None) is not None:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
if getattr(self.text_encoder_2, "peft_config", None) is not None:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)

View File

@@ -1112,6 +1112,7 @@ def create_text_encoder_from_open_clip_checkpoint(
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
else:
text_model_dict[diffusers_key] = checkpoint[key]

View File

@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from importlib import import_module
from typing import Callable, Optional, Union
@@ -510,15 +509,6 @@ class Attention(nn.Module):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
return self.processor(
self,
hidden_states,

View File

@@ -20,7 +20,7 @@ from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, deprecate, logging
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -43,24 +43,6 @@ from .unets.unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, block_out_channels):
incorrect_attention_head_dim_name = False
if "CrossAttnDownBlock2D" in down_block_types or mid_block_type == "UNetMidBlock2DCrossAttn":
incorrect_attention_head_dim_name = True
if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
else:
# we use attention_head_dim to calculate num_attention_heads
if isinstance(attention_head_dim, int):
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
]
return num_attention_heads
@dataclass
class ControlNetOutput(BaseOutput):
"""
@@ -240,22 +222,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
):
super().__init__()
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads`."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type, block_out_channels
)
logger.warning(
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
f" the model will be configured with `num_attention_heads` {num_attention_heads}."
)
attention_head_dim = None
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if num_attention_heads is None:
raise ValueError("`num_attention_heads` cannot be None.")
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
@@ -270,13 +245,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# we use num_attention_heads to calculate attention_head_dim
attention_head_dim = [
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
@@ -386,6 +354,12 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
@@ -411,7 +385,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
@@ -448,7 +422,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,

View File

@@ -27,7 +27,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils import logging
from ..utils.import_utils import is_transformers_available
@@ -82,9 +82,6 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
class PatchedLoraProjection(torch.nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
super().__init__()
from ..models.lora import LoRALinearLayer
@@ -296,16 +293,10 @@ class LoRACompatibleConv(nn.Conv2d):
"""
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("set_lora_layer", "1.0.0", deprecation_message)
self.lora_layer = lora_layer
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
@@ -380,15 +371,10 @@ class LoRACompatibleLinear(nn.Linear):
"""
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
self.lora_layer = lora_layer
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
deprecate("set_lora_layer", "1.0.0", deprecation_message)
self.lora_layer = lora_layer
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):

View File

@@ -119,7 +119,6 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
@@ -141,7 +140,6 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -163,7 +161,6 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -194,7 +191,6 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -222,7 +218,6 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -248,7 +243,6 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
add_self_attention=True if not add_downsample else False,
)
@@ -341,7 +335,6 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -365,7 +358,6 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -390,7 +382,6 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
@@ -421,7 +412,6 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -450,7 +440,6 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
@@ -479,7 +468,6 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
@@ -567,7 +555,6 @@ class UNetMidBlock2D(nn.Module):
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
@@ -615,15 +602,13 @@ class UNetMidBlock2D(nn.Module):
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim
for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=num_attention_heads,
heads=in_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -695,7 +680,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
@@ -709,9 +693,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
if attention_head_dim is None:
attention_head_dim = in_channels // num_attention_heads
# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -737,8 +718,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -751,8 +732,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -843,7 +824,6 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
@@ -858,9 +838,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim
self.num_heads = num_attention_heads
self.num_heads = in_channels // self.attention_head_dim
# there is always at least one resnet
resnets = [
@@ -971,7 +949,6 @@ class AttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -988,9 +965,6 @@ class AttnDownBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -1010,7 +984,7 @@ class AttnDownBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1100,7 +1074,6 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -1117,9 +1090,6 @@ class CrossAttnDownBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -1142,8 +1112,8 @@ class CrossAttnDownBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -1157,8 +1127,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -1425,7 +1395,6 @@ class AttnDownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
@@ -1441,9 +1410,6 @@ class AttnDownEncoderBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
if resnet_time_scale_shift == "spatial":
@@ -1478,7 +1444,7 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1529,7 +1495,6 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_downsample: bool = True,
@@ -1544,9 +1509,6 @@ class AttnSkipDownBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
@@ -1567,7 +1529,7 @@ class AttnSkipDownBlock2D(nn.Module):
self.attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1827,7 +1789,6 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -1844,9 +1805,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
attentions = []
self.attention_head_dim = attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads
self.num_heads = out_channels // self.attention_head_dim
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1874,7 +1833,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=num_attention_heads,
heads=self.num_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -2068,7 +2027,6 @@ class KCrossAttnDownBlock2D(nn.Module):
num_layers: int = 4,
resnet_group_size: int = 32,
add_downsample: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
@@ -2078,9 +2036,6 @@ class KCrossAttnDownBlock2D(nn.Module):
resnets = []
attentions = []
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.has_cross_attention = True
for i in range(num_layers):
@@ -2104,9 +2059,9 @@ class KCrossAttnDownBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
dim=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
out_channels,
out_channels // attention_head_dim,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
@@ -2203,7 +2158,6 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
upsample_type: str = "conv",
@@ -2220,9 +2174,6 @@ class AttnUpBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2244,7 +2195,7 @@ class AttnUpBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2329,7 +2280,6 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2346,9 +2296,6 @@ class CrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -2373,8 +2320,8 @@ class CrossAttnUpBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -2388,8 +2335,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -2687,7 +2634,6 @@ class AttnUpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2703,9 +2649,6 @@ class AttnUpDecoderBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
@@ -2742,7 +2685,7 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2794,7 +2737,6 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_upsample: bool = True,
@@ -2829,13 +2771,10 @@ class AttnSkipUpBlock2D(nn.Module):
)
attention_head_dim = out_channels
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.attentions.append(
Attention(
out_channels,
heads=num_attention_heads,
heads=out_channels // attention_head_dim,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -3143,7 +3082,6 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -3159,9 +3097,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads
self.num_heads = out_channels // self.attention_head_dim
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -3191,8 +3127,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=num_attention_heads,
dim_head=attention_head_dim,
heads=self.num_heads,
dim_head=self.attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
@@ -3398,7 +3334,6 @@ class KCrossAttnUpBlock2D(nn.Module):
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
@@ -3415,11 +3350,6 @@ class KCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim
if num_attention_heads is not None:
logger.warn(
"`num_attention_heads` argument is passed but ignored. The number of attention heads is determined by `attention_head_dim`, `in_channels` and `out_channels`."
)
# in_channels, and out_channels for the block (k-unet)
k_in_channels = out_channels if is_first_block else 2 * out_channels
k_out_channels = in_channels
@@ -3453,11 +3383,11 @@ class KCrossAttnUpBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
dim=k_out_channels if (i == num_layers - 1) else out_channels,
num_attention_heads=k_out_channels // attention_head_dim
k_out_channels if (i == num_layers - 1) else out_channels,
k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attention_head_dim,
attention_head_dim=attention_head_dim,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,

View File

@@ -55,28 +55,6 @@ from .unet_2d_blocks import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels):
incorrect_attention_head_dim_name = False
if (
"CrossAttnDownBlock2D" in down_block_types
or "CrossAttnUpBlock2D" in up_block_types
or mid_block_type == "UNetMidBlock2DCrossAttn"
):
incorrect_attention_head_dim_name = True
if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
else:
# we use attention_head_dim to calculate num_attention_heads
if isinstance(attention_head_dim, int):
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
]
return num_attention_heads
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
@@ -247,21 +225,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
self.sample_size = sample_size
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` instead."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
)
logger.warning(
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
f"the model will be configured with `num_attention_heads` {num_attention_heads}."
)
attention_head_dim = None
if num_attention_heads is None:
raise ValueError("`num_attention_heads` cannot be None.")
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
@@ -282,6 +259,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
@@ -296,14 +278,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
# make sure num_attention_heads is a tuple
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# we use num_attention_heads to calculate attention_head_dim
attention_head_dim = [
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
@@ -445,6 +419,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
@@ -492,7 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -510,7 +490,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@@ -526,7 +505,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
@@ -558,7 +536,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = (
@@ -607,7 +584,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=reversed_attention_head_dim[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
dropout=dropout,
)
self.up_blocks.append(up_block)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
@@ -163,6 +164,14 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
class AutoPipelineForText2Image(ConfigMixin):
r"""
@@ -382,7 +391,7 @@ class AutoPipelineForText2Image(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
@@ -659,7 +668,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
@@ -934,7 +943,7 @@ class AutoPipelineForInpainting(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)

View File

@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1785,6 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
return hidden_states, output_states
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
def __init__(
self,
@@ -1895,6 +1897,7 @@ class UpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
def __init__(
self,
@@ -2068,6 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
@@ -2223,6 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
return hidden_states
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
@@ -2369,6 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
return hidden_states
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__(
self,

View File

@@ -981,9 +981,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
custom_revision (`str`, *optional*):
custom_revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if youre downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1422,7 +1423,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
@@ -1472,7 +1472,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
hook.remove()
# make sure the model is in the same state as before calling it
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
self.enable_model_cpu_offload()
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
@@ -1508,7 +1508,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device
if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import DiffusionPipeline
from diffusers.utils.testing_utils import torch_device
class PEFTLoRALoading(unittest.TestCase):
def get_dummy_inputs(self):
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
"generator": torch.manual_seed(0),
}
return pipeline_inputs
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
# This LoRA was obtained using similarly as how it's done in the training scripts.
# For details on how the LoRA was obtained, refer to:
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
inputs = self.get_dummy_inputs()
outputs = sd_pipe(**inputs).images
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
self.assertTrue(outputs.shape == (1, 64, 64, 3))
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
# This LoRA was obtained using similarly as how it's done in the training scripts.
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
inputs = self.get_dummy_inputs()
outputs = sd_pipe(**inputs).images
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
self.assertTrue(outputs.shape == (1, 64, 64, 3))
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)

View File

@@ -21,7 +21,6 @@ from collections import OrderedDict
from pathlib import Path
import torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from diffusers import (
AutoPipelineForImage2Image,
@@ -49,20 +48,6 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
class AutoPipelineFastTest(unittest.TestCase):
@property
def dummy_image_encoder(self):
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=1,
projection_dim=1,
num_hidden_layers=1,
num_attention_heads=1,
image_size=1,
intermediate_size=1,
patch_size=1,
)
return CLIPVisionModelWithProjection(config)
def test_from_pipe_consistent(self):
pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
@@ -219,20 +204,6 @@ class AutoPipelineFastTest(unittest.TestCase):
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
def test_from_pipe_optional_components(self):
image_encoder = self.dummy_image_encoder
pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe",
image_encoder=image_encoder,
)
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
assert pipe.image_encoder is not None
pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
assert pipe.image_encoder is None
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):

View File

@@ -36,10 +36,10 @@ from diffusers import (
LMSDiscreteScheduler,
UniPCMultistepScheduler,
VQDiffusionScheduler,
logging,
)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger, torch_device
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -48,9 +48,6 @@ from ..others.test_utils import TOKEN, USER, is_staging_test
torch.backends.cuda.matmul.allow_tf32 = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SchedulerObject(SchedulerMixin, ConfigMixin):
config_name = "config.json"
@@ -256,60 +253,6 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_classes = ()
forward_default_kwargs = ()
@property
def default_num_inference_steps(self):
return 50
@property
def default_timestep(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
try:
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
timestep = scheduler.timesteps[0]
except NotImplementedError:
logger.warning(
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
f" `default_timestep` will be set to the default value of 1."
)
timestep = 1
return timestep
# NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
# default_timestep comes earlier in the timestep schedule than default_timestep_2)
@property
def default_timestep_2(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
try:
scheduler_config = self.get_scheduler_config()
scheduler = self.scheduler_classes[0](**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
if len(scheduler.timesteps) >= 2:
timestep_2 = scheduler.timesteps[1]
else:
logger.warning(
f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
f" scheduler of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
f" will be used."
)
timestep_2 = 0
except NotImplementedError:
logger.warning(
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
f" `default_timestep_2` will be set to the default value of 0."
)
timestep_2 = 0
return timestep_2
@property
def dummy_sample(self):
batch_size = 4
@@ -370,7 +313,6 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
time_step = time_step if time_step is not None else self.default_timestep
for scheduler_class in self.scheduler_classes:
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
@@ -429,7 +371,6 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
time_step = time_step if time_step is not None else self.default_timestep
for scheduler_class in self.scheduler_classes:
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
@@ -470,10 +411,10 @@ class SchedulerCommonTest(unittest.TestCase):
def test_from_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
timestep = self.default_timestep
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
@@ -556,10 +497,10 @@ class SchedulerCommonTest(unittest.TestCase):
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
num_inference_steps = kwargs.pop("num_inference_steps", None)
timestep_0 = self.default_timestep
timestep_1 = self.default_timestep_2
timestep_0 = 1
timestep_1 = 0
for scheduler_class in self.scheduler_classes:
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
@@ -617,9 +558,9 @@ class SchedulerCommonTest(unittest.TestCase):
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
num_inference_steps = kwargs.pop("num_inference_steps", 50)
timestep = self.default_timestep
timestep = 0
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
timestep = 1
@@ -703,7 +644,7 @@ class SchedulerCommonTest(unittest.TestCase):
continue
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.default_num_inference_steps)
scheduler.set_timesteps(100)
sample = self.dummy_sample.to(torch_device)
if scheduler_class == CMStochasticIterativeScheduler: