Mirror of https://github.com/huggingface/diffusers.git

Compare commits: pixart-tes...rename-att (17 commits)
- 88c704684c
- a45fed728c
- 754b0532d7
- 80251ed035
- ef8c0bf51d
- 04be74ed94
- e4bee5d8df
- 9b1ff58b40
- e7696e20f9
- 4b89aeffe1
- 0a1daadef8
- 371f765908
- 75aee39eac
- 215e6804d3
- 9254d1f39a
- e1bdcc7af3
- 84905ca728
@@ -26,9 +26,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
python3.9 -m pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
torch \
torchvision \
torchaudio \
invisible_watermark && \
python3.9 -m pip install --no-cache-dir \
accelerate \

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
torch \
torchvision \
torchaudio \
invisible_watermark \
--extra-index-url https://download.pytorch.org/whl/cpu && \
python3 -m pip install --no-cache-dir \

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
torch \
torchvision \
torchaudio \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
accelerate \

@@ -25,9 +25,9 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
python3 -m pip install --no-cache-dir \
torch==2.1.2 \
torchvision==0.16.2 \
torchaudio==2.1.2 \
torch \
torchvision \
torchaudio \
invisible_watermark && \
python3 -m pip install --no-cache-dir \
accelerate \
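With the version pins dropped, these images now install whichever torch, torchvision, and torchaudio releases pip resolves at build time. A quick sanity check of the resolved versions inside a built image might look like this (a minimal sketch, not part of the Dockerfiles):

```python
# Print the torch stack versions resolved inside the image (illustrative check only).
import torch
import torchaudio
import torchvision

print(torch.__version__, torchvision.__version__, torchaudio.__version__)
```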
@@ -56,6 +56,60 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
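The snippets in this section pass `clip_model` and `feature_extractor` objects that are created elsewhere. A minimal sketch of how they might be built (the CLIP checkpoint id is an assumption, not part of this diff):

```py
from transformers import CLIPImageProcessor, CLIPModel

# Any CLIP checkpoint can be used here; this id is only an example.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
```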
### Load from a local file

Community pipelines can also be loaded from a local directory if you pass its path instead. The passed directory must contain a `pipeline.py` file that defines the pipeline class in order to load it successfully.

```py
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="./path/to/pipeline_directory/",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    use_safetensors=True,
)
```

### Load from a specific version

By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.

<hfoptions id="version">
<hfoption id="main">

For example, to load from the `main` branch:

```py
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="clip_guided_stable_diffusion",
    custom_revision="main",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    use_safetensors=True,
)
```

</hfoption>
<hfoption id="older version">

For example, to load from a previous version of Diffusers like `v0.25.0`:

```py
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="clip_guided_stable_diffusion",
    custom_revision="v0.25.0",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    use_safetensors=True,
)
```

</hfoption>
</hfoptions>

For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them, and if you're interested in adding a community pipeline, check out the [How to contribute a community pipeline](contribute_pipeline) guide!

## Community components
@@ -376,18 +376,14 @@ After training, LoRA weights can be loaded very easily into the original pipeline
load the original pipeline:

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("base-model-name").to("cuda")
```

Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
Next, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).

```python
pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
pipe.load_lora_weights("path-to-the-lora-checkpoint")
```

Finally, we can run the model in inference.
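The hunk is truncated at this point; a minimal inference sketch for the updated snippet could look like the following (prompt and step count are illustrative):

```python
# Illustrative prompt; any prompt matching the LoRA's training data works.
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog.png")
```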
@@ -49,6 +49,7 @@ from diffusers import (
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

@@ -195,7 +196,7 @@ def import_model_class_from_model_name_or_path(
raise ValueError(f"{model_class} is not supported.")

def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
def save_model_card(repo_id: str, image_logs: dict = None, base_model: str = None, repo_folder: str = None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
@@ -209,27 +210,25 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- t2iadapter
inference: true
---
"""
model_card = f"""
model_description = f"""
# t2iadapter-{repo_id}

These are t2iadapter weights trained on {base_model} with new type of conditioning.
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)

tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "t2iadapter"]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))

def parse_args(input_args=None):
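For reference, a hypothetical call to the refactored helper might look like this (the repository id, base model, and output folder are placeholders, not values from this diff):

```python
# Hypothetical invocation after training; all arguments below are placeholders.
save_model_card(
    repo_id="your-username/t2iadapter-example",
    image_logs=None,
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    repo_folder="./output",
)
```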
@@ -67,8 +67,8 @@ DATASET_NAME_MAPPING = {
def save_model_card(
args,
repo_id: str,
images=None,
repo_folder=None,
images: list = None,
repo_folder: str = None,
):
img_str = ""
if len(images) > 0:

@@ -56,7 +56,9 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__, log_level="INFO")

def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
def save_model_card(
repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))

@@ -58,6 +58,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

@@ -70,33 +71,20 @@ logger = get_logger(__name__)

def save_model_card(
repo_id: str,
images=None,
base_model=str,
dataset_name=str,
train_text_encoder=False,
repo_folder=None,
vae_path=None,
images: list = None,
base_model: str = None,
dataset_name: str = None,
train_text_encoder: bool = False,
repo_folder: str = None,
vae_path: str = None,
):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"\n"
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
dataset: {dataset_name}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
inference: true
---
"""
model_card = f"""
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}

These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
@@ -106,8 +94,19 @@ LoRA for the text encoder was enabled: {train_text_encoder}.

Special VAE used for training: {vae_path}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)

tags = ["stable-diffusion-xl", "stable-diffusion-xl-diffusers", "text-to-image", "diffusers", "lora"]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))

def import_model_class_from_model_name_or_path(

@@ -66,12 +66,12 @@ DATASET_NAME_MAPPING = {

def save_model_card(
repo_id: str,
images=None,
validation_prompt=None,
base_model=str,
dataset_name=str,
repo_folder=None,
vae_path=None,
images: list = None,
validation_prompt: str = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
vae_path: str = None,
):
img_str = ""
for i, image in enumerate(images):

@@ -167,7 +167,10 @@ vae_conversion_map_attn = [

def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
if not w.ndim == 1:
return w.reshape(*w.shape, 1, 1)
else:
return w
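The added guard matters for 1-D tensors such as biases, which the old code would also have reshaped. A minimal sketch of the new behaviour (shapes are illustrative):

```python
import torch

w_linear = torch.zeros(320, 320)  # HF linear weight -> reshaped to an SD conv2d weight
w_bias = torch.zeros(320)         # 1-D bias -> now returned unchanged

print(reshape_weight_for_sd(w_linear).shape)  # torch.Size([320, 320, 1, 1])
print(reshape_weight_for_sd(w_bias).shape)    # torch.Size([320])
```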
def convert_vae_state_dict(vae_state_dict):

@@ -321,11 +324,18 @@ if __name__ == "__main__":
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

# Convert text encoder 1
text_enc_dict = convert_openai_text_enc_state_dict(text_enc_dict)
text_enc_dict = {"conditioner.embedders.0.transformer." + k: v for k, v in text_enc_dict.items()}

# Convert text encoder 2
text_enc_2_dict = convert_openclip_text_enc_state_dict(text_enc_2_dict)
text_enc_2_dict = {"conditioner.embedders.1.model." + k: v for k, v in text_enc_2_dict.items()}
# We call the `.T.contiguous()` to match what's done in
# https://github.com/huggingface/diffusers/blob/84905ca7287876b925b6bf8e9bb92fec21c78764/src/diffusers/loaders/single_file_utils.py#L1085
text_enc_2_dict["conditioner.embedders.1.model.text_projection"] = text_enc_2_dict.pop(
"conditioner.embedders.1.model.text_projection.weight"
).T.contiguous()

# Put together new checkpoint
state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict, **text_enc_2_dict}
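The transpose reflects the two storage conventions: the HF text model keeps `text_projection` as an `nn.Linear` weight applied as `x @ W.T`, while the original checkpoint stores the matrix that is multiplied directly as `x @ text_projection`. A toy illustration of the conversion (the dimension is illustrative):

```python
import torch

hf_weight = torch.randn(1280, 1280)         # HF nn.Linear weight: (out_features, in_features)
open_clip_param = hf_weight.T.contiguous()  # stored so that x @ open_clip_param == x @ hf_weight.T

assert torch.equal(open_clip_param.T, hf_weight)
```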
@@ -170,7 +170,10 @@ vae_extra_conversion_map = [

def reshape_weight_for_sd(w):
# convert HF linear weights to SD conv2d weights
return w.reshape(*w.shape, 1, 1)
if not w.ndim == 1:
return w.reshape(*w.shape, 1, 1)
else:
return w

def convert_vae_state_dict(vae_state_dict):
setup.py
@@ -126,8 +126,8 @@ _deps = [
"regex!=2019.12.17",
"requests",
"tensorboard",
"torch>=1.4,<2.2.0",
"torchvision<0.17",
"torch>=1.4",
"torchvision",
"transformers>=4.25.1",
"urllib3<=2.0.0",
]

@@ -38,8 +38,8 @@ deps = {
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
"torch": "torch>=1.4,<2.2.0",
"torchvision": "torchvision<0.17",
"torch": "torch>=1.4",
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"urllib3": "urllib3<=2.0.0",
}

@@ -1112,7 +1112,6 @@ def create_text_encoder_from_open_clip_checkpoint(
text_model_dict[diffusers_key + ".q_proj.bias"] = weight_value[:text_proj_dim]
text_model_dict[diffusers_key + ".k_proj.bias"] = weight_value[text_proj_dim : text_proj_dim * 2]
text_model_dict[diffusers_key + ".v_proj.bias"] = weight_value[text_proj_dim * 2 :]

else:
text_model_dict[diffusers_key] = checkpoint[key]

@@ -20,7 +20,7 @@ from torch.nn import functional as F

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlNetMixin
from ..utils import BaseOutput, logging
from ..utils import BaseOutput, deprecate, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -43,6 +43,24 @@ from .unets.unet_2d_condition import UNet2DConditionModel

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, block_out_channels):
incorrect_attention_head_dim_name = False
if "CrossAttnDownBlock2D" in down_block_types or mid_block_type == "UNetMidBlock2DCrossAttn":
incorrect_attention_head_dim_name = True

if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
else:
# we use attention_head_dim to calculate num_attention_heads
if isinstance(attention_head_dim, int):
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
]
return num_attention_heads
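Concretely, when cross-attention blocks are present the old `attention_head_dim` value is taken over directly as the head count (that was the long-standing naming mix-up), and only purely self-attentive configs have it treated as an actual per-head dimension. A minimal sketch with illustrative values:

```python
# Illustrative values only; real values come from a model's config.
block_out_channels = (320, 640, 1280, 1280)

# Cross-attention UNets historically stored the head count under `attention_head_dim`:
correct_incorrect_names(8, ("CrossAttnDownBlock2D",), "UNetMidBlock2DCrossAttn", block_out_channels)
# -> 8, reinterpreted as `num_attention_heads`

# Purely self-attentive blocks really did store a per-head dimension:
correct_incorrect_names(64, ("DownBlock2D",), None, block_out_channels)
# -> [5, 10, 20, 20]
```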
@dataclass
class ControlNetOutput(BaseOutput):
"""
@@ -222,15 +240,22 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
):
super().__init__()

# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads`."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type, block_out_channels
)
logger.warning(
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
f" the model will be configured with `num_attention_heads` {num_attention_heads}."
)
attention_head_dim = None

# Check inputs
if num_attention_heads is None:
raise ValueError("`num_attention_heads` cannot be None.")

if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
@@ -245,6 +270,13 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

# we use num_attention_heads to calculate attention_head_dim
attention_head_dim = [
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
@@ -354,12 +386,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

# down
output_channel = block_out_channels[0]

@@ -385,7 +411,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i],
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
@@ -422,6 +448,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,

@@ -119,6 +119,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
downsample_type=downsample_type,
@@ -140,6 +141,7 @@ def get_down_block(
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -161,6 +163,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -191,6 +194,7 @@ def get_down_block(
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -218,6 +222,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -243,6 +248,7 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
add_self_attention=True if not add_downsample else False,
)
@@ -335,6 +341,7 @@ def get_up_block(
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
@@ -358,6 +365,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
@@ -382,6 +390,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
@@ -412,6 +421,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
@@ -440,6 +450,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
@@ -468,6 +479,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)

@@ -555,6 +567,7 @@ class UNetMidBlock2D(nn.Module):
attn_groups: Optional[int] = None,
resnet_pre_norm: bool = True,
add_attention: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
):
@@ -602,13 +615,15 @@ class UNetMidBlock2D(nn.Module):
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
)
attention_head_dim = in_channels
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim

for _ in range(num_layers):
if self.add_attention:
attentions.append(
Attention(
in_channels,
heads=in_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
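The rename keeps the resulting attention layout identical when only `attention_head_dim` is given; the head count is simply derived once instead of being recomputed inline. A small illustration (values are illustrative):

```python
in_channels, attention_head_dim = 512, 64        # illustrative values
num_attention_heads = in_channels // attention_head_dim
assert num_attention_heads == 8                  # Attention(heads=8, dim_head=64), same as before
```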
@@ -680,6 +695,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
dual_cross_attention: bool = False,
@@ -693,6 +709,9 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

if attention_head_dim is None:
attention_head_dim = in_channels // num_attention_heads

# support for variable transformer layers per block
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
@@ -718,8 +737,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -732,8 +751,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -824,6 +843,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
@@ -838,7 +858,9 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
self.attention_head_dim = attention_head_dim
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

self.num_heads = in_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = in_channels // attention_head_dim
self.num_heads = num_attention_heads

# there is always at least one resnet
resnets = [
@@ -949,6 +971,7 @@ class AttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -965,6 +988,9 @@ class AttnDownBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -984,7 +1010,7 @@ class AttnDownBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1074,6 +1100,7 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
@@ -1090,6 +1117,9 @@ class CrossAttnDownBlock2D(nn.Module):

self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers

@@ -1112,8 +1142,8 @@ class CrossAttnDownBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -1127,8 +1157,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -1395,6 +1425,7 @@ class AttnDownEncoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
@@ -1410,6 +1441,9 @@ class AttnDownEncoderBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
if resnet_time_scale_shift == "spatial":
@@ -1444,7 +1478,7 @@ class AttnDownEncoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1495,6 +1529,7 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_downsample: bool = True,
@@ -1509,6 +1544,9 @@ class AttnSkipDownBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
@@ -1529,7 +1567,7 @@ class AttnSkipDownBlock2D(nn.Module):
self.attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -1789,6 +1827,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -1805,7 +1844,9 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
attentions = []

self.attention_head_dim = attention_head_dim
self.num_heads = out_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads

for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
@@ -1833,7 +1874,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
heads=num_attention_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -2027,6 +2068,7 @@ class KCrossAttnDownBlock2D(nn.Module):
num_layers: int = 4,
resnet_group_size: int = 32,
add_downsample: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 64,
add_self_attention: bool = False,
resnet_eps: float = 1e-5,
@@ -2036,6 +2078,9 @@ class KCrossAttnDownBlock2D(nn.Module):
resnets = []
attentions = []

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

self.has_cross_attention = True

for i in range(num_layers):
@@ -2059,9 +2104,9 @@ class KCrossAttnDownBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
out_channels,
out_channels // attention_head_dim,
attention_head_dim,
dim=out_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,
@@ -2158,6 +2203,7 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
upsample_type: str = "conv",
@@ -2174,6 +2220,9 @@ class AttnUpBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -2195,7 +2244,7 @@ class AttnUpBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2280,6 +2329,7 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
attention_head_dim: Optional[int] = None,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2296,6 +2346,9 @@ class CrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads

if attention_head_dim is None:
attention_head_dim = out_channels // num_attention_heads

if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers

@@ -2320,8 +2373,8 @@ class CrossAttnUpBlock2D(nn.Module):
if not dual_cross_attention:
attentions.append(
Transformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
@@ -2335,8 +2388,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
attentions.append(
DualTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
@@ -2634,6 +2687,7 @@ class AttnUpDecoderBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
@@ -2649,6 +2703,9 @@ class AttnUpDecoderBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels

@@ -2685,7 +2742,7 @@ class AttnUpDecoderBlock2D(nn.Module):
attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -2737,6 +2794,7 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
output_scale_factor: float = np.sqrt(2.0),
add_upsample: bool = True,
@@ -2771,10 +2829,13 @@ class AttnSkipUpBlock2D(nn.Module):
)
attention_head_dim = out_channels

if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim

self.attentions.append(
Attention(
out_channels,
heads=out_channels // attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
@@ -3082,6 +3143,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
@@ -3097,7 +3159,9 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim

self.num_heads = out_channels // self.attention_head_dim
if num_attention_heads is None:
num_attention_heads = out_channels // attention_head_dim
self.num_heads = num_attention_heads

for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
@@ -3127,8 +3191,8 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
Attention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=self.attention_head_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
@@ -3334,6 +3398,7 @@ class KCrossAttnUpBlock2D(nn.Module):
resnet_eps: float = 1e-5,
resnet_act_fn: str = "gelu",
resnet_group_size: int = 32,
num_attention_heads: Optional[int] = None,
attention_head_dim: int = 1, # attention dim_head
cross_attention_dim: int = 768,
add_upsample: bool = True,
@@ -3350,6 +3415,11 @@ class KCrossAttnUpBlock2D(nn.Module):
self.has_cross_attention = True
self.attention_head_dim = attention_head_dim

if num_attention_heads is not None:
logger.warn(
"`num_attention_heads` argument is passed but ignored. The number of attention heads is determined by `attention_head_dim`, `in_channels` and `out_channels`."
)

# in_channels, and out_channels for the block (k-unet)
k_in_channels = out_channels if is_first_block else 2 * out_channels
k_out_channels = in_channels
@@ -3383,11 +3453,11 @@ class KCrossAttnUpBlock2D(nn.Module):
)
attentions.append(
KAttentionBlock(
k_out_channels if (i == num_layers - 1) else out_channels,
k_out_channels // attention_head_dim
dim=k_out_channels if (i == num_layers - 1) else out_channels,
num_attention_heads=k_out_channels // attention_head_dim
if (i == num_layers - 1)
else out_channels // attention_head_dim,
attention_head_dim,
attention_head_dim=attention_head_dim,
cross_attention_dim=cross_attention_dim,
temb_channels=temb_channels,
attention_bias=True,

@@ -55,6 +55,28 @@ from .unet_2d_blocks import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

def correct_incorrect_names(attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels):
incorrect_attention_head_dim_name = False
if (
"CrossAttnDownBlock2D" in down_block_types
or "CrossAttnUpBlock2D" in up_block_types
or mid_block_type == "UNetMidBlock2DCrossAttn"
):
incorrect_attention_head_dim_name = True

if incorrect_attention_head_dim_name:
num_attention_heads = attention_head_dim
else:
# we use attention_head_dim to calculate num_attention_heads
if isinstance(attention_head_dim, int):
num_attention_heads = [out_channels // attention_head_dim for out_channels in block_out_channels]
else:
num_attention_heads = [
out_channels // attn_dim for out_channels, attn_dim in zip(block_out_channels, attention_head_dim)
]
return num_attention_heads

@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
@@ -225,20 +247,21 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,

self.sample_size = sample_size

if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
if attention_head_dim is not None:
deprecation_message = " `attention_head_dim` is deprecated and will be removed in a future version. Use `num_attention_heads` instead."
deprecate("attention_head_dim not None", "1.0.0", deprecation_message, standard_warn=False)
num_attention_heads = correct_incorrect_names(
attention_head_dim, down_block_types, mid_block_type, up_block_types, block_out_channels
)
logger.warning(
f"corrected potentially incorrect arguments attention_head_dim {attention_head_dim}."
f"the model will be configured with `num_attention_heads` {num_attention_heads}."
)
attention_head_dim = None

# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
if num_attention_heads is None:
raise ValueError("`num_attention_heads` cannot be None.")

# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
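Previously, passing `num_attention_heads` to `UNet2DConditionModel.__init__` raised a `ValueError`; on this branch it becomes the preferred argument, and old configs that still carry `attention_head_dim` keep working but emit a deprecation warning. A minimal sketch with a tiny, illustrative config (not a real checkpoint):

```python
from diffusers import UNet2DConditionModel

# Tiny illustrative config; on this branch `attention_head_dim=8` is remapped to
# `num_attention_heads` and a deprecation warning is emitted.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
```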
@@ -259,11 +282,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)

if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)

if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
@@ -278,6 +296,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if isinstance(layer_number_per_block, list):
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")

# make sure num_attention_heads is a tuple
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

# we use num_attention_heads to calculate attention_head_dim
attention_head_dim = [
out_channels // num_heads for out_channels, num_heads in zip(block_out_channels, num_attention_heads)
]
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
@@ -419,12 +445,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False

if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)

if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)

if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

@@ -472,7 +492,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i],
dropout=dropout,
)
self.down_blocks.append(down_block)
@@ -490,6 +510,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
@@ -505,6 +526,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
attention_head_dim=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
@@ -536,6 +558,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_attention_head_dim = list(reversed(attention_head_dim))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = (
@@ -584,7 +607,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=reversed_attention_head_dim[i],
dropout=dropout,
)
self.up_blocks.append(up_block)

@@ -268,7 +268,6 @@ class GLIGENTextBoundingboxProjection(nn.Module):
return objs

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
class UNetFlatConditionModel(ModelMixin, ConfigMixin):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1786,7 +1785,6 @@ class CrossAttnDownBlockFlat(nn.Module):
return hidden_states, output_states

# Copied from diffusers.models.unets.unet_2d_blocks.UpBlock2D with UpBlock2D->UpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class UpBlockFlat(nn.Module):
def __init__(
self,
@@ -1897,7 +1895,6 @@ class UpBlockFlat(nn.Module):
return hidden_states

# Copied from diffusers.models.unets.unet_2d_blocks.CrossAttnUpBlock2D with CrossAttnUpBlock2D->CrossAttnUpBlockFlat, ResnetBlock2D->ResnetBlockFlat, Upsample2D->LinearMultiDim
class CrossAttnUpBlockFlat(nn.Module):
def __init__(
self,
@@ -2071,7 +2068,6 @@ class CrossAttnUpBlockFlat(nn.Module):
return hidden_states

# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2D with UNetMidBlock2D->UNetMidBlockFlat, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlat(nn.Module):
"""
A 2D UNet mid-block [`UNetMidBlockFlat`] with multiple residual blocks and optional attention blocks.
@@ -2227,7 +2223,6 @@ class UNetMidBlockFlat(nn.Module):
return hidden_states

# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DCrossAttn with UNetMidBlock2DCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
@@ -2374,7 +2369,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
return hidden_states

# Copied from diffusers.models.unets.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatSimpleCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
def __init__(
self,

@@ -981,10 +981,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
custom_revision (`str`, *optional*, defaults to `"main"`):
custom_revision (`str`, *optional*):
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
`revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1423,6 +1422,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
@@ -1472,7 +1472,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
hook.remove()

# make sure the model is in the same state as before calling it
self.enable_model_cpu_offload()
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))

def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
r"""
@@ -1508,6 +1508,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

device_type = torch_device.type
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
self._offload_device = device

if self.device.type != "cpu":
self.to("cpu", silence_dtype_warnings=True)
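Because the offload device is now remembered in `_offload_device`, re-enabling offload after the hooks are removed reuses the device the user originally chose instead of falling back to the default `"cuda"`. A minimal usage sketch (model id and device are illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload(device="cuda:1")  # remembered via `pipe._offload_device`
```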