mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-10 14:34:55 +08:00
Compare commits
2 Commits
rename-att
...
post_relea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95de8000ec | ||
|
|
2dfc2e8c47 |
13
.github/workflows/pr_tests.yml
vendored
13
.github/workflows/pr_tests.yml
vendored
@@ -34,6 +34,11 @@ jobs:
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: LoRA
|
||||
framework: lora
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_lora
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: docker-cpu
|
||||
@@ -89,6 +94,14 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests
|
||||
if: ${{ matrix.config.framework == 'lora' }}
|
||||
run: |
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
|
||||
@@ -26,9 +26,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
invisible_watermark && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
invisible_watermark \
|
||||
--extra-index-url https://download.pytorch.org/whl/cpu && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
torch==2.1.2 \
|
||||
torchvision==0.16.2 \
|
||||
torchaudio==2.1.2 \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
|
||||
@@ -56,60 +56,6 @@ pipeline = DiffusionPipeline.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
### Load from a local file
|
||||
|
||||
Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a `pipeline.py` file that contains the pipeline class in order to successfully load it.
|
||||
|
||||
```py
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
custom_pipeline="./path/to/pipeline_directory/",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
use_safetensors=True,
|
||||
)
|
||||
```
|
||||
|
||||
### Load from a specific version
|
||||
|
||||
By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.
|
||||
|
||||
<hfoptions id="version">
|
||||
<hfoption id="main">
|
||||
|
||||
For example, to load from the `main` branch:
|
||||
|
||||
```py
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
custom_revision="main",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
use_safetensors=True,
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="older version">
|
||||
|
||||
For example, to load from a previous version of Diffusers like `v0.25.0`:
|
||||
|
||||
```py
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
custom_pipeline="clip_guided_stable_diffusion",
|
||||
custom_revision="v0.25.0",
|
||||
clip_model=clip_model,
|
||||
feature_extractor=feature_extractor,
|
||||
use_safetensors=True,
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
|
||||
|
||||
## Community components
|
||||
|
||||
@@ -42,7 +42,6 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.27.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Marigold monocular depth prediction pipeline.
|
||||
|
||||
@@ -376,14 +376,18 @@ After training, LoRA weights can be loaded very easily into the original pipelin
|
||||
load the original pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
pipe = DiffusionPipeline.from_pretrained("base-model-name").to("cuda")
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
Next, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).
|
||||
Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("path-to-the-lora-checkpoint")
|
||||
pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
|
||||
```
|
||||
|
||||
Finally, we can run the model in inference.
|
||||
|
||||
@@ -49,7 +49,6 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -196,7 +195,7 @@ def import_model_class_from_model_name_or_path(
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
if image_logs is not None:
|
||||
img_str = "You can find some example images below.\n"
|
||||
@@ -210,25 +209,27 @@ def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = Non
|
||||
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- t2iadapter
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# t2iadapter-{repo_id}
|
||||
|
||||
These are t2iadapter weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
"""
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
|
||||
@@ -45,7 +45,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -67,8 +66,8 @@ DATASET_NAME_MAPPING = {
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
images: list = None,
|
||||
repo_folder: str = None,
|
||||
images=None,
|
||||
repo_folder=None,
|
||||
):
|
||||
img_str = ""
|
||||
if len(images) > 0:
|
||||
@@ -76,7 +75,21 @@ def save_model_card(
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
model_description = f"""
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {args.pretrained_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
@@ -119,21 +132,10 @@ These are the key hyperparameters used during training:
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_description += wandb_info
|
||||
model_card += wandb_info
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
|
||||
|
||||
@@ -45,7 +45,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDif
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import cast_training_params, compute_snr
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -56,39 +55,32 @@ check_min_version("0.27.0.dev0")
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None
|
||||
):
|
||||
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# LoRA text2image fine-tuning - {repo_id}
|
||||
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
||||
{img_str}
|
||||
"""
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"stable-diffusion",
|
||||
"stable-diffusion-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
"lora",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
@@ -58,7 +58,6 @@ from diffusers.utils import (
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -71,20 +70,33 @@ logger = get_logger(__name__)
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images: list = None,
|
||||
base_model: str = None,
|
||||
dataset_name: str = None,
|
||||
train_text_encoder: bool = False,
|
||||
repo_folder: str = None,
|
||||
vae_path: str = None,
|
||||
images=None,
|
||||
base_model=str,
|
||||
dataset_name=str,
|
||||
train_text_encoder=False,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = ""
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
dataset: {dataset_name}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# LoRA text2image fine-tuning - {repo_id}
|
||||
|
||||
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
|
||||
@@ -94,19 +106,8 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
"""
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
|
||||
@@ -48,7 +48,6 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, U
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -66,45 +65,41 @@ DATASET_NAME_MAPPING = {
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images: list = None,
|
||||
validation_prompt: str = None,
|
||||
base_model: str = None,
|
||||
dataset_name: str = None,
|
||||
repo_folder: str = None,
|
||||
vae_path: str = None,
|
||||
images=None,
|
||||
validation_prompt=None,
|
||||
base_model=str,
|
||||
dataset_name=str,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
img_str = ""
|
||||
for i, image in enumerate(images):
|
||||
image.save(os.path.join(repo_folder, f"image_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
model_description = f"""
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
dataset: {dataset_name}
|
||||
tags:
|
||||
- stable-diffusion-xl
|
||||
- stable-diffusion-xl-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
|
||||
This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n
|
||||
{img_str}
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
"""
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=base_model,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = [
|
||||
"stable-diffusion-xl",
|
||||
"stable-diffusion-xl-diffusers",
|
||||
"text-to-image",
|
||||
"diffusers",
|
||||
]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
|
||||
@@ -167,10 +167,7 @@ vae_conversion_map_attn = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
@@ -324,18 +321,11 @@ if __name__ == "__main__":
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Convert text encoder 1
|
||||
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
|
||||
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}
|
||||
|
||||
# Convert text encoder 2
|
||||
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
|
||||
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
|
||||
# We call the `.T.contiguous()` to match what's done in
|
||||
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
|
||||
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
|
||||
"conditioner.embedders.1.model.text_projection.weight"
|
||||
).T.contiguous()
|
||||
|
||||
# Put together new checkpoint
|
||||
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
|
||||
|
||||
@@ -170,10 +170,7 @@ vae_extra_conversion_map = [
|
||||
|
||||
def reshape_weight_for_sd(w):
|
||||
# convert HF linear weights to SD conv2d weights
|
||||
if not w.ndim == 1:
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
else:
|
||||
return w
|
||||
return w.reshape(*w.shape, 1, 1)
|
||||
|
||||
|
||||
def convert_vae_state_dict(vae_state_dict):
|
||||
|
||||
4
setup.py
4
setup.py
@@ -126,8 +126,8 @@ _deps = [
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"torch>=1.4,<2.2.0",
|
||||
"torchvision<0.17",
|
||||
"transformers>=4.25.1",
|
||||
"urllib3<=2.0.0",
|
||||
]
|
||||
|
||||
@@ -38,8 +38,8 @@ deps = {
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4",
|
||||
"torchvision": "torchvision",
|
||||
"torch": "torch>=1.4,<2.2.0",
|
||||
"torchvision": "torchvision<0.17",
|
||||
"transformers": "transformers>=4.25.1",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
}
|
||||
|
||||
@@ -166,7 +166,8 @@ class IPAdapterMixin:
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=Path(subfolder, "image_encoder").as_posix(),
|
||||
).to(self.device, dtype=self.dtype)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
self.image_encoder = image_encoder
|
||||
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
|
||||
else:
|
||||
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
@@ -25,7 +26,7 @@ from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from .. import __version__
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
@@ -33,6 +34,7 @@ from ..utils import (
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
@@ -49,9 +51,10 @@ from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -103,9 +106,6 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
@@ -397,17 +397,16 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
if all(key.startswith("unet.unet") for key in keys):
|
||||
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
|
||||
deprecate("unet.unet keys", "0.27", deprecation_message)
|
||||
|
||||
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
@@ -428,7 +427,9 @@ class LoraLoaderMixin:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
@@ -517,11 +518,6 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -543,21 +539,34 @@ class LoraLoaderMixin:
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
if USE_PEFT_BACKEND:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_B.weight"
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
else:
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
|
||||
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
|
||||
|
||||
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
|
||||
if patch_mlp:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
@@ -567,25 +576,84 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
lora_scale,
|
||||
network_alphas,
|
||||
rank=rank,
|
||||
patch_mlp=patch_mlp,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
is_pipeline_offloaded = _pipeline is not None and any(
|
||||
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
|
||||
for c in _pipeline.components.values()
|
||||
)
|
||||
if is_pipeline_offloaded and low_cpu_mem_usage:
|
||||
low_cpu_mem_usage = True
|
||||
logger.info(
|
||||
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
device = next(iter(text_encoder_lora_state_dict.values())).device
|
||||
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
|
||||
)
|
||||
else:
|
||||
load_state_dict_results = text_encoder.load_state_dict(
|
||||
text_encoder_lora_state_dict, strict=False
|
||||
)
|
||||
unexpected_keys = load_state_dict_results.unexpected_keys
|
||||
|
||||
if len(unexpected_keys) != 0:
|
||||
raise ValueError(
|
||||
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(
|
||||
getattr(component, "_hf_hook"), AlignDevicesHook
|
||||
)
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
@@ -621,8 +689,6 @@ class LoraLoaderMixin:
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
@@ -639,6 +705,8 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
@@ -686,20 +754,118 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
remove_method = recurse_remove_peft_layers
|
||||
if USE_PEFT_BACKEND:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
|
||||
# In case text encoder have no Lora attached
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
if USE_PEFT_BACKEND:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_linear_layer = None
|
||||
attn_module.k_proj.lora_linear_layer = None
|
||||
attn_module.v_proj.lora_linear_layer = None
|
||||
attn_module.out_proj.lora_linear_layer = None
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_linear_layer = None
|
||||
mlp_module.fc2.lora_linear_layer = None
|
||||
|
||||
@classmethod
|
||||
def _modify_text_encoder(
|
||||
cls,
|
||||
text_encoder,
|
||||
lora_scale=1,
|
||||
network_alphas=None,
|
||||
rank: Union[Dict[str, int], int] = 4,
|
||||
dtype=None,
|
||||
patch_mlp=False,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
with ctx():
|
||||
model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
|
||||
|
||||
lora_parameters.extend(model.lora_linear_layer.parameters())
|
||||
return model
|
||||
|
||||
# First, remove any monkey-patch that might have been applied before
|
||||
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
|
||||
|
||||
lora_parameters = []
|
||||
network_alphas = {} if network_alphas is None else network_alphas
|
||||
is_network_alphas_populated = len(network_alphas) > 0
|
||||
|
||||
for name, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
query_alpha = network_alphas.pop(name + ".to_q_lora.down.weight.alpha", None)
|
||||
key_alpha = network_alphas.pop(name + ".to_k_lora.down.weight.alpha", None)
|
||||
value_alpha = network_alphas.pop(name + ".to_v_lora.down.weight.alpha", None)
|
||||
out_alpha = network_alphas.pop(name + ".to_out_lora.down.weight.alpha", None)
|
||||
|
||||
if isinstance(rank, dict):
|
||||
current_rank = rank.pop(f"{name}.out_proj.lora_linear_layer.up.weight")
|
||||
else:
|
||||
current_rank = rank
|
||||
|
||||
attn_module.q_proj = create_patched_linear_lora(
|
||||
attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.k_proj = create_patched_linear_lora(
|
||||
attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.v_proj = create_patched_linear_lora(
|
||||
attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
attn_module.out_proj = create_patched_linear_lora(
|
||||
attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if patch_mlp:
|
||||
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
fc1_alpha = network_alphas.pop(name + ".fc1.lora_linear_layer.down.weight.alpha", None)
|
||||
fc2_alpha = network_alphas.pop(name + ".fc2.lora_linear_layer.down.weight.alpha", None)
|
||||
|
||||
current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
|
||||
current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
|
||||
|
||||
mlp_module.fc1 = create_patched_linear_lora(
|
||||
mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
|
||||
)
|
||||
mlp_module.fc2 = create_patched_linear_lora(
|
||||
mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
|
||||
)
|
||||
|
||||
if is_network_alphas_populated and len(network_alphas) > 0:
|
||||
raise ValueError(
|
||||
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
|
||||
)
|
||||
|
||||
return lora_parameters
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
@@ -873,8 +1039,6 @@ class LoraLoaderMixin:
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
if self.num_fused_loras > 1:
|
||||
@@ -886,26 +1050,52 @@ class LoraLoaderMixin:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
|
||||
if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
|
||||
"backend to use this argument by installing latest PEFT and transformers."
|
||||
" `pip install -U peft transformers`"
|
||||
)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
|
||||
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -930,18 +1120,40 @@ class LoraLoaderMixin:
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if unfuse_unet:
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
if not USE_PEFT_BACKEND:
|
||||
unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._unfuse_lora()
|
||||
attn_module.k_proj._unfuse_lora()
|
||||
attn_module.v_proj._unfuse_lora()
|
||||
attn_module.out_proj._unfuse_lora()
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._unfuse_lora()
|
||||
mlp_module.fc2._unfuse_lora()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
@@ -1222,9 +1434,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
# pipeline.
|
||||
@@ -1329,13 +1538,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
if USE_PEFT_BACKEND:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -1112,6 +1112,7 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
|
||||
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
|
||||
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]
|
||||
|
||||
else:
|
||||
text_model_dict[diffusers_key] = checkpoint[key]
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from importlib import import_module
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -510,15 +509,6 @@ class Attention(nn.Module):
|
||||
# The `Attention` class can call different attention processors / attention functions
|
||||
# here we simply pass along all tensors to the selected processor class
|
||||
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
|
||||
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch.nn import functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import FromOriginalControlNetMixin
|
||||
from ..utils import BaseOutput, deprecate, logging
|
||||
from ..utils import BaseOutput, logging
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
@@ -43,24 +43,6 @@ from .unets.unet_2d_condition import UNet2DConditionModel
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, block_out_channels):
|
||||
incorrect_attention_head_dim_name = False
|
||||
if "CrossAttnDownBlock2D" in down_block_types or mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
incorrect_attention_head_dim_name = True
|
||||
|
||||
if incorrect_attention_head_dim_name:
|
||||
num_attention_heads = attention_head_dim
|
||||
else:
|
||||
# we use attention_head_dim to calculate num_attention_heads
|
||||
if isinstance(attention_head_dim, int):
|
||||
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
|
||||
else:
|
||||
num_attention_heads = [
|
||||
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
|
||||
]
|
||||
return num_attention_heads
|
||||
|
||||
|
||||
@dataclass
|
||||
class ControlNetOutput(BaseOutput):
|
||||
"""
|
||||
@@ -240,22 +222,15 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if attention_head_dim is not None:
|
||||
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads`."
|
||||
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
num_attention_heads = correct_incorrect_names(
|
||||
attention_head_dim, down_block_types, mid_block_type, block_out_channels
|
||||
)
|
||||
logger.warning(
|
||||
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
|
||||
f" the model will be configured with `num_attention_heads` {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = None
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if num_attention_heads is None:
|
||||
raise ValueError("`num_attention_heads` cannot be None.")
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
@@ -270,13 +245,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# we use num_attention_heads to calculate attention_head_dim
|
||||
attention_head_dim = [
|
||||
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
|
||||
]
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
@@ -386,6 +354,12 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
@@ -411,7 +385,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
@@ -448,7 +422,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
attention_head_dim=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import logging
|
||||
from ..utils.import_utils import is_transformers_available
|
||||
|
||||
|
||||
@@ -82,9 +82,6 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
|
||||
class PatchedLoraProjection(torch.nn.Module):
|
||||
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
|
||||
deprecation_message = "Use of `PatchedLoraProjection` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("PatchedLoraProjection", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__()
|
||||
from ..models.lora import LoRALinearLayer
|
||||
|
||||
@@ -296,16 +293,10 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRAConv2dLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleConv` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
@@ -380,15 +371,10 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
"""
|
||||
|
||||
def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
|
||||
deprecation_message = "Use of `LoRACompatibleLinear` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
deprecation_message = "Use of `set_lora_layer()` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
|
||||
deprecate("set_lora_layer", "1.0.0", deprecation_message)
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
|
||||
@@ -119,7 +119,6 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
downsample_type=downsample_type,
|
||||
@@ -141,7 +140,6 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
@@ -163,7 +161,6 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
@@ -194,7 +191,6 @@ def get_down_block(
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
@@ -222,7 +218,6 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
@@ -248,7 +243,6 @@ def get_down_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
add_self_attention=True if not add_downsample else False,
|
||||
)
|
||||
@@ -341,7 +335,6 @@ def get_up_block(
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
@@ -365,7 +358,6 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
skip_time_act=resnet_skip_time_act,
|
||||
@@ -390,7 +382,6 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
upsample_type=upsample_type,
|
||||
@@ -421,7 +412,6 @@ def get_up_block(
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
@@ -450,7 +440,6 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
temb_channels=temb_channels,
|
||||
@@ -479,7 +468,6 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
|
||||
@@ -567,7 +555,6 @@ class UNetMidBlock2D(nn.Module):
|
||||
attn_groups: Optional[int] = None,
|
||||
resnet_pre_norm: bool = True,
|
||||
add_attention: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
):
|
||||
@@ -615,15 +602,13 @@ class UNetMidBlock2D(nn.Module):
|
||||
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
||||
)
|
||||
attention_head_dim = in_channels
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = in_channels // attention_head_dim
|
||||
|
||||
for _ in range(num_layers):
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=in_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -695,7 +680,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
output_scale_factor: float = 1.0,
|
||||
cross_attention_dim: int = 1280,
|
||||
dual_cross_attention: bool = False,
|
||||
@@ -709,9 +693,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self.num_attention_heads = num_attention_heads
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = in_channels // num_attention_heads
|
||||
|
||||
# support for variable transformer layers per block
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
@@ -737,8 +718,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -751,8 +732,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
in_channels // num_attention_heads,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -843,7 +824,6 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
cross_attention_dim: int = 1280,
|
||||
@@ -858,9 +838,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
self.attention_head_dim = attention_head_dim
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = in_channels // attention_head_dim
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_heads = in_channels // self.attention_head_dim
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
@@ -971,7 +949,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
downsample_padding: int = 1,
|
||||
@@ -988,9 +965,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
@@ -1010,7 +984,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -1100,7 +1074,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
downsample_padding: int = 1,
|
||||
@@ -1117,9 +1090,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = out_channels // num_attention_heads
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
@@ -1142,8 +1112,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -1157,8 +1127,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -1425,7 +1395,6 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
add_downsample: bool = True,
|
||||
@@ -1441,9 +1410,6 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
if resnet_time_scale_shift == "spatial":
|
||||
@@ -1478,7 +1444,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -1529,7 +1495,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = np.sqrt(2.0),
|
||||
add_downsample: bool = True,
|
||||
@@ -1544,9 +1509,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
self.resnets.append(
|
||||
@@ -1567,7 +1529,7 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
self.attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -1827,7 +1789,6 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
@@ -1844,9 +1805,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.attention_head_dim = attention_head_dim
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_heads = out_channels // self.attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
@@ -1874,7 +1833,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
Attention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=self.num_heads,
|
||||
dim_head=attention_head_dim,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
@@ -2068,7 +2027,6 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
num_layers: int = 4,
|
||||
resnet_group_size: int = 32,
|
||||
add_downsample: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 64,
|
||||
add_self_attention: bool = False,
|
||||
resnet_eps: float = 1e-5,
|
||||
@@ -2078,9 +2036,6 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -2104,9 +2059,9 @@ class KCrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
attentions.append(
|
||||
KAttentionBlock(
|
||||
dim=out_channels,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
out_channels,
|
||||
out_channels // attention_head_dim,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
temb_channels=temb_channels,
|
||||
attention_bias=True,
|
||||
@@ -2203,7 +2158,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
upsample_type: str = "conv",
|
||||
@@ -2220,9 +2174,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -2244,7 +2195,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -2329,7 +2280,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: int = 1,
|
||||
attention_head_dim: Optional[int] = None,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
@@ -2346,9 +2296,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
self.has_cross_attention = True
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
if attention_head_dim is None:
|
||||
attention_head_dim = out_channels // num_attention_heads
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
|
||||
|
||||
@@ -2373,8 +2320,8 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
if not dual_cross_attention:
|
||||
attentions.append(
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=transformer_layers_per_block[i],
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -2388,8 +2335,8 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
attentions.append(
|
||||
DualTransformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads,
|
||||
out_channels // num_attention_heads,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
@@ -2687,7 +2634,6 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = 1.0,
|
||||
add_upsample: bool = True,
|
||||
@@ -2703,9 +2649,6 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
input_channels = in_channels if i == 0 else out_channels
|
||||
|
||||
@@ -2742,7 +2685,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -2794,7 +2737,6 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
output_scale_factor: float = np.sqrt(2.0),
|
||||
add_upsample: bool = True,
|
||||
@@ -2829,13 +2771,10 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
)
|
||||
attention_head_dim = out_channels
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
|
||||
self.attentions.append(
|
||||
Attention(
|
||||
out_channels,
|
||||
heads=num_attention_heads,
|
||||
heads=out_channels // attention_head_dim,
|
||||
dim_head=attention_head_dim,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
@@ -3143,7 +3082,6 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1,
|
||||
cross_attention_dim: int = 1280,
|
||||
output_scale_factor: float = 1.0,
|
||||
@@ -3159,9 +3097,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
self.has_cross_attention = True
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
if num_attention_heads is None:
|
||||
num_attention_heads = out_channels // attention_head_dim
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_heads = out_channels // self.attention_head_dim
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
@@ -3191,8 +3127,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
Attention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
heads=self.num_heads,
|
||||
dim_head=self.attention_head_dim,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
@@ -3398,7 +3334,6 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
resnet_eps: float = 1e-5,
|
||||
resnet_act_fn: str = "gelu",
|
||||
resnet_group_size: int = 32,
|
||||
num_attention_heads: Optional[int] = None,
|
||||
attention_head_dim: int = 1, # attention dim_head
|
||||
cross_attention_dim: int = 768,
|
||||
add_upsample: bool = True,
|
||||
@@ -3415,11 +3350,6 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
self.has_cross_attention = True
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
if num_attention_heads is not None:
|
||||
logger.warn(
|
||||
"`num_attention_heads` argument is passed but ignored. The number of attention heads is determined by `attention_head_dim`, `in_channels` and `out_channels`."
|
||||
)
|
||||
|
||||
# in_channels, and out_channels for the block (k-unet)
|
||||
k_in_channels = out_channels if is_first_block else 2 * out_channels
|
||||
k_out_channels = in_channels
|
||||
@@ -3453,11 +3383,11 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
attentions.append(
|
||||
KAttentionBlock(
|
||||
dim=k_out_channels if (i == num_layers - 1) else out_channels,
|
||||
num_attention_heads=k_out_channels // attention_head_dim
|
||||
k_out_channels if (i == num_layers - 1) else out_channels,
|
||||
k_out_channels // attention_head_dim
|
||||
if (i == num_layers - 1)
|
||||
else out_channels // attention_head_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
temb_channels=temb_channels,
|
||||
attention_bias=True,
|
||||
|
||||
@@ -55,28 +55,6 @@ from .unet_2d_blocks import (
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels):
|
||||
incorrect_attention_head_dim_name = False
|
||||
if (
|
||||
"CrossAttnDownBlock2D" in down_block_types
|
||||
or "CrossAttnUpBlock2D" in up_block_types
|
||||
or mid_block_type == "UNetMidBlock2DCrossAttn"
|
||||
):
|
||||
incorrect_attention_head_dim_name = True
|
||||
|
||||
if incorrect_attention_head_dim_name:
|
||||
num_attention_heads = attention_head_dim
|
||||
else:
|
||||
# we use attention_head_dim to calculate num_attention_heads
|
||||
if isinstance(attention_head_dim, int):
|
||||
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
|
||||
else:
|
||||
num_attention_heads = [
|
||||
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
|
||||
]
|
||||
return num_attention_heads
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNet2DConditionOutput(BaseOutput):
|
||||
"""
|
||||
@@ -247,21 +225,20 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
|
||||
self.sample_size = sample_size
|
||||
|
||||
if attention_head_dim is not None:
|
||||
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` instead."
|
||||
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
|
||||
num_attention_heads = correct_incorrect_names(
|
||||
attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels
|
||||
if num_attention_heads is not None:
|
||||
raise ValueError(
|
||||
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
||||
)
|
||||
logger.warning(
|
||||
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
|
||||
f"the model will be configured with `num_attention_heads` {num_attention_heads}."
|
||||
)
|
||||
attention_head_dim = None
|
||||
|
||||
if num_attention_heads is None:
|
||||
raise ValueError("`num_attention_heads` cannot be None.")
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
@@ -282,6 +259,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
||||
@@ -296,14 +278,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
if isinstance(layer_number_per_block, list):
|
||||
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
||||
|
||||
# make sure num_attention_heads is a tuple
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# we use num_attention_heads to calculate attention_head_dim
|
||||
attention_head_dim = [
|
||||
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
|
||||
]
|
||||
# input
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
@@ -445,6 +419,12 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
if mid_block_only_cross_attention is None:
|
||||
mid_block_only_cross_attention = False
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(cross_attention_dim, int):
|
||||
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
||||
|
||||
@@ -492,7 +472,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
attention_head_dim=attention_head_dim[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
@@ -510,7 +490,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
attention_head_dim=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
@@ -526,7 +505,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim[-1],
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
attention_head_dim=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
@@ -558,7 +536,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_attention_head_dim = list(reversed(attention_head_dim))
|
||||
reversed_layers_per_block = list(reversed(layers_per_block))
|
||||
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
||||
reversed_transformer_layers_per_block = (
|
||||
@@ -607,7 +584,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
|
||||
resnet_skip_time_act=resnet_skip_time_act,
|
||||
resnet_out_scale_factor=resnet_out_scale_factor,
|
||||
cross_attention_norm=cross_attention_norm,
|
||||
attention_head_dim=reversed_attention_head_dim[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
@@ -163,6 +164,14 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool
|
||||
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
|
||||
|
||||
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
class AutoPipelineForText2Image(ConfigMixin):
|
||||
r"""
|
||||
|
||||
@@ -382,7 +391,7 @@ class AutoPipelineForText2Image(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)
|
||||
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
@@ -659,7 +668,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)
|
||||
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
@@ -934,7 +943,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
)
|
||||
|
||||
# define expected module and optional kwargs given the pipeline signature
|
||||
expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
|
||||
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
|
||||
|
||||
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
|
||||
|
||||
|
||||
@@ -268,6 +268,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):
|
||||
return objs
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
|
||||
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
||||
@@ -1785,6 +1786,7 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class UpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1895,6 +1897,7 @@ class UpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
|
||||
class CrossAttnUpBlockFlat(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2068,6 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlat(nn.Module):
|
||||
"""
|
||||
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
|
||||
@@ -2223,6 +2227,7 @@ class UNetMidBlockFlat(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2369,6 +2374,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -981,9 +981,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
custom_revision (`str`, *optional*):
|
||||
custom_revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
||||
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
|
||||
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
|
||||
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
||||
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
||||
@@ -1422,7 +1423,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
@@ -1472,7 +1472,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
hook.remove()
|
||||
|
||||
# make sure the model is in the same state as before calling it
|
||||
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
||||
self.enable_model_cpu_offload()
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
||||
r"""
|
||||
@@ -1508,7 +1508,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
device_type = torch_device.type
|
||||
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
||||
self._offload_device = device
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
|
||||
2193
tests/lora/test_lora_layers_old_backend.py
Normal file
2193
tests/lora/test_lora_layers_old_backend.py
Normal file
File diff suppressed because it is too large
Load Diff
64
tests/lora/test_peft_lora_in_non_peft.py
Normal file
64
tests/lora/test_peft_lora_in_non_peft.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
class PEFTLoRALoading(unittest.TestCase):
|
||||
def get_dummy_inputs(self):
|
||||
pipeline_inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
"output_type": "np",
|
||||
"generator": torch.manual_seed(0),
|
||||
}
|
||||
return pipeline_inputs
|
||||
|
||||
def test_stable_diffusion_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sd-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
# For details on how the LoRA was obtained, refer to:
|
||||
# https://hf.co/datasets/diffusers/notebooks/blob/main/check_logits_with_serialization_peft_lora.py
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sd-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.5396, 0.5707, 0.477, 0.4665, 0.5419, 0.4594, 0.4857, 0.4741, 0.4804])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_stable_diffusion_xl_peft_lora_loading_in_non_peft(self):
|
||||
sd_pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-sdxl-pipe").to(torch_device)
|
||||
# This LoRA was obtained using similarly as how it's done in the training scripts.
|
||||
sd_pipe.load_lora_weights("hf-internal-testing/tiny-sdxl-lora-peft")
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
outputs = sd_pipe(**inputs).images
|
||||
|
||||
predicted_slice = outputs[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.613, 0.5566, 0.54, 0.4162, 0.4042, 0.4596, 0.5374, 0.5286, 0.5038])
|
||||
|
||||
self.assertTrue(outputs.shape == (1, 64, 64, 3))
|
||||
assert np.allclose(expected_slice, predicted_slice, atol=1e-3, rtol=1e-3)
|
||||
@@ -21,7 +21,6 @@ from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
@@ -49,20 +48,6 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
|
||||
|
||||
|
||||
class AutoPipelineFastTest(unittest.TestCase):
|
||||
@property
|
||||
def dummy_image_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPVisionConfig(
|
||||
hidden_size=1,
|
||||
projection_dim=1,
|
||||
num_hidden_layers=1,
|
||||
num_attention_heads=1,
|
||||
image_size=1,
|
||||
intermediate_size=1,
|
||||
patch_size=1,
|
||||
)
|
||||
return CLIPVisionModelWithProjection(config)
|
||||
|
||||
def test_from_pipe_consistent(self):
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
|
||||
@@ -219,20 +204,6 @@ class AutoPipelineFastTest(unittest.TestCase):
|
||||
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
|
||||
assert "controlnet" in pipe_control_img2img.components
|
||||
|
||||
def test_from_pipe_optional_components(self):
|
||||
image_encoder = self.dummy_image_encoder
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-pipe",
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
|
||||
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
|
||||
assert pipe.image_encoder is not None
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
|
||||
assert pipe.image_encoder is None
|
||||
|
||||
|
||||
@slow
|
||||
class AutoPipelineIntegrationTest(unittest.TestCase):
|
||||
|
||||
@@ -36,10 +36,10 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
VQDiffusionScheduler,
|
||||
logging,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.testing_utils import CaptureLogger, torch_device
|
||||
|
||||
from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
@@ -48,9 +48,6 @@ from ..others.test_utils import TOKEN, USER, is_staging_test
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SchedulerObject(SchedulerMixin, ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@@ -256,60 +253,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_classes = ()
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def default_num_inference_steps(self):
|
||||
return 50
|
||||
|
||||
@property
|
||||
def default_timestep(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
try:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = self.scheduler_classes[0](**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
timestep = scheduler.timesteps[0]
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
|
||||
f" `default_timestep` will be set to the default value of 1."
|
||||
)
|
||||
timestep = 1
|
||||
|
||||
return timestep
|
||||
|
||||
# NOTE: currently taking the convention that default_timestep > default_timestep_2 (alternatively,
|
||||
# default_timestep comes earlier in the timestep schedule than default_timestep_2)
|
||||
@property
|
||||
def default_timestep_2(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.get("num_inference_steps", self.default_num_inference_steps)
|
||||
|
||||
try:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = self.scheduler_classes[0](**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
if len(scheduler.timesteps) >= 2:
|
||||
timestep_2 = scheduler.timesteps[1]
|
||||
else:
|
||||
logger.warning(
|
||||
f"Using num_inference_steps from the scheduler testing class's default config leads to a timestep"
|
||||
f" scheduler of length {len(scheduler.timesteps)} < 2. The default `default_timestep_2` value of 0"
|
||||
f" will be used."
|
||||
)
|
||||
timestep_2 = 0
|
||||
except NotImplementedError:
|
||||
logger.warning(
|
||||
f"The scheduler {self.__class__.__name__} does not implement a `get_scheduler_config` method."
|
||||
f" `default_timestep_2` will be set to the default value of 0."
|
||||
)
|
||||
timestep_2 = 0
|
||||
|
||||
return timestep_2
|
||||
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
@@ -370,7 +313,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
time_step = time_step if time_step is not None else self.default_timestep
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
# TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
|
||||
@@ -429,7 +371,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
time_step = time_step if time_step is not None else self.default_timestep
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
@@ -470,10 +411,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_from_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
timestep = self.default_timestep
|
||||
timestep = 1
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
timestep = float(timestep)
|
||||
|
||||
@@ -556,10 +497,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
timestep_0 = self.default_timestep
|
||||
timestep_1 = self.default_timestep_2
|
||||
timestep_0 = 1
|
||||
timestep_1 = 0
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
@@ -617,9 +558,9 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", self.default_num_inference_steps)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", 50)
|
||||
|
||||
timestep = self.default_timestep
|
||||
timestep = 0
|
||||
if len(self.scheduler_classes) > 0 and self.scheduler_classes[0] == IPNDMScheduler:
|
||||
timestep = 1
|
||||
|
||||
@@ -703,7 +644,7 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
continue
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(self.default_num_inference_steps)
|
||||
scheduler.set_timesteps(100)
|
||||
|
||||
sample = self.dummy_sample.to(torch_device)
|
||||
if scheduler_class == CMStochasticIterativeScheduler:
|
||||
|
||||
Reference in New Issue
Block a user