mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-13 07:54:45 +08:00
Compare commits
16 Commits
custom-cod
...
unet-refac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43f72b2940 | ||
|
|
36e52ef801 | ||
|
|
95fcd61904 | ||
|
|
8f824bf0ab | ||
|
|
89bcec9a0d | ||
|
|
f404b6926e | ||
|
|
69f4b8ff5a | ||
|
|
f0ec02350a | ||
|
|
9fdd6de30f | ||
|
|
aa3b85bdd6 | ||
|
|
0953fed52b | ||
|
|
bd375a8034 | ||
|
|
17105d973c | ||
|
|
32e04da6cf | ||
|
|
c1e812b8fd | ||
|
|
784f4e9646 |
1181
src/diffusers/models/unets/unet_stable_diffusion.py
Normal file
1181
src/diffusers/models/unets/unet_stable_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
205
src/diffusers/models/unets/unet_utils.py
Normal file
205
src/diffusers/models/unets/unet_utils.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
from typing import Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ..attention_processor import (
|
||||||
|
ADDED_KV_ATTENTION_PROCESSORS,
|
||||||
|
CROSS_ATTENTION_PROCESSORS,
|
||||||
|
Attention,
|
||||||
|
AttentionProcessor,
|
||||||
|
AttnAddedKVProcessor,
|
||||||
|
AttnProcessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UNet2DConditionModelUtilsMixin:
|
||||||
|
def _check_config(
|
||||||
|
self,
|
||||||
|
down_block_types: Tuple[str],
|
||||||
|
up_block_types: Tuple[str],
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]],
|
||||||
|
block_out_channels: Tuple[int],
|
||||||
|
layers_per_block: [int, Tuple[int]],
|
||||||
|
cross_attention_dim: Union[int, Tuple[int]],
|
||||||
|
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
|
||||||
|
reverse_transformer_layers_per_block: bool,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
||||||
|
):
|
||||||
|
if len(down_block_types) != len(up_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
||||||
|
for layer_number_per_block in transformer_layers_per_block:
|
||||||
|
if isinstance(layer_number_per_block, list):
|
||||||
|
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attn_processors(self):
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnAddedKVProcessor()
|
||||||
|
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnProcessor()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_attn_processor(processor)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value=False):
|
||||||
|
if hasattr(module, "gradient_checkpointing"):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def enable_freeu(self, s1, s2, b1, b2):
|
||||||
|
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||||
|
|
||||||
|
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||||
|
|
||||||
|
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||||
|
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
s1 (`float`):
|
||||||
|
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
s2 (`float`):
|
||||||
|
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||||
|
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||||
|
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||||
|
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||||
|
"""
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
setattr(upsample_block, "s1", s1)
|
||||||
|
setattr(upsample_block, "s2", s2)
|
||||||
|
setattr(upsample_block, "b1", b1)
|
||||||
|
setattr(upsample_block, "b2", b2)
|
||||||
|
|
||||||
|
def disable_freeu(self):
|
||||||
|
"""Disables the FreeU mechanism."""
|
||||||
|
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
for k in freeu_keys:
|
||||||
|
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
||||||
|
setattr(upsample_block, k, None)
|
||||||
|
|
||||||
|
def fuse_qkv_projections(self):
|
||||||
|
"""
|
||||||
|
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
||||||
|
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This API is 🧪 experimental.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
"""
|
||||||
|
self.original_attn_processors = None
|
||||||
|
|
||||||
|
for _, attn_processor in self.attn_processors.items():
|
||||||
|
if "Added" in str(attn_processor.__class__.__name__):
|
||||||
|
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||||
|
|
||||||
|
self.original_attn_processors = self.attn_processors
|
||||||
|
|
||||||
|
for module in self.modules():
|
||||||
|
if isinstance(module, Attention):
|
||||||
|
module.fuse_projections(fuse=True)
|
||||||
|
|
||||||
|
def unfuse_qkv_projections(self):
|
||||||
|
"""Disables the fused QKV projection if enabled.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This API is 🧪 experimental.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self.original_attn_processors is not None:
|
||||||
|
self.set_attn_processor(self.original_attn_processors)
|
||||||
Reference in New Issue
Block a user