Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00.

Compare commits: shared-var...unet-confi (1 commit: 327243fd95). The commit reworks how `UNet2DConditionModel` resolves the historically misnamed `attention_head_dim` and the newer `num_attention_heads` argument.
```diff
@@ -55,6 +55,42 @@ from .unet_2d_blocks import (
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def _check_to_assign_to_num_attention_heads(attention_head_dim, down_block_types, mid_block_type, up_block_types):
+    # If a new config is passed with the correct value for `num_attention_heads` and `attention_head_dim` is None,
+    # we do not need to reassign anything.
+    if attention_head_dim is None:
+        return False
+
+    elif (
+        "CrossAttnDownBlock2D" in down_block_types
+        or "CrossAttnUpBlock2D" in up_block_types
+        or mid_block_type == "UNetMidBlock2DCrossAttn"
+    ):
+        return True
+
+    return False
+
+
+def _set_attention_parameters(num_attention_heads, attention_head_dim, block_out_channels):
+    if num_attention_heads is None:
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(block_out_channels)
+
+        num_attention_heads = [
+            out_channel // attention_head_dim[i] for i, out_channel in enumerate(block_out_channels)
+        ]
+
+    elif attention_head_dim is None:
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(block_out_channels)
+
+        attention_head_dim = [
+            out_channel // num_attention_heads[i] for i, out_channel in enumerate(block_out_channels)
+        ]
+
+    return num_attention_heads, attention_head_dim
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
```
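For intuition, here is a minimal standalone sketch (illustrative, not part of the commit) of the branch `_set_attention_parameters` takes when only `attention_head_dim` is configured. The `block_out_channels` values are assumed, SD-v1.5-style defaults:

```python
# Illustrative sketch of the `num_attention_heads is None` branch above.
# The channel widths below are assumed for demonstration; any per-block
# tuple works the same way.
block_out_channels = (320, 640, 1280, 1280)
attention_head_dim = 8  # a bare int is first broadcast to one entry per block

attention_head_dim = (attention_head_dim,) * len(block_out_channels)
num_attention_heads = [
    out_channel // attention_head_dim[i] for i, out_channel in enumerate(block_out_channels)
]

print(attention_head_dim)   # (8, 8, 8, 8)
print(num_attention_heads)  # [40, 80, 160, 160]
```

Each block satisfies `out_channels == heads * head_dim`, which is why integer division recovers whichever of the two values is missing.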
```diff
@@ -225,18 +261,22 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
         self.sample_size = sample_size
 
-        if num_attention_heads is not None:
+        if (num_attention_heads is not None) and (attention_head_dim is not None):
             raise ValueError(
-                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+                "It is not possible to configure the UNet with both `num_attention_heads` and `attention_head_dim` at the same time. Please set only one of these values."
             )
 
-        # If `num_attention_heads` is not defined (which is the case for most models)
-        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
-        # The reason for this behavior is to correct for incorrectly named variables that were introduced
-        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
-        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
-        # which is why we correct for the naming here.
-        num_attention_heads = num_attention_heads or attention_head_dim
+        _should_assign_num_attention_heads = _check_to_assign_to_num_attention_heads(
+            attention_head_dim, down_block_types, mid_block_type, up_block_types
+        )
+        if _should_assign_num_attention_heads:
+            num_attention_heads = attention_head_dim
+            attention_head_dim = None
+            logger.warning(
+                "`attention_head_dim` has been incorrectly configured for this model and will be reassigned to `num_attention_heads`. "
+                "Further reference: https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131"
+            )
+        self.register_to_config(num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
 
         # Check inputs
         if len(down_block_types) != len(up_block_types):
```
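To see when the reassignment fires, here is a hedged sketch of the condition evaluated by `_check_to_assign_to_num_attention_heads`, inlined with assumed SD-style block types (none of these values come from the commit itself):

```python
# Illustrative: legacy SD-style configs stored the head *count* under
# `attention_head_dim`, and they always contain cross-attention blocks,
# so the check fires and the value is moved over with a warning.
attention_head_dim = 8  # assumed legacy value; really a head count
num_attention_heads = None
down_block_types = ("CrossAttnDownBlock2D",) * 3 + ("DownBlock2D",)
mid_block_type = "UNetMidBlock2DCrossAttn"
up_block_types = ("UpBlock2D",) + ("CrossAttnUpBlock2D",) * 3

should_reassign = attention_head_dim is not None and (
    "CrossAttnDownBlock2D" in down_block_types
    or "CrossAttnUpBlock2D" in up_block_types
    or mid_block_type == "UNetMidBlock2DCrossAttn"
)
if should_reassign:
    num_attention_heads, attention_head_dim = attention_head_dim, None

print(num_attention_heads, attention_head_dim)  # 8 None
```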
```diff
@@ -254,16 +294,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
-            )
-
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
-            )
-
         if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
```
```diff
@@ -410,6 +440,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
+        # Compute both `attention_head_dim` and `num_attention_heads` for each block
+        num_attention_heads, attention_head_dim = _set_attention_parameters(
+            num_attention_heads, attention_head_dim, block_out_channels
+        )
+
         if isinstance(only_cross_attention, bool):
             if mid_block_only_cross_attention is None:
                 mid_block_only_cross_attention = only_cross_attention
```
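The complementary case, sketched as a direct call to the helper introduced in the first hunk (channel widths again assumed): when only `num_attention_heads` is given, the per-block head dimension is derived instead.

```python
# Assuming `_set_attention_parameters` from the first hunk is in scope.
num_attention_heads, attention_head_dim = _set_attention_parameters(
    num_attention_heads=8,    # bare int, broadcast to one entry per block
    attention_head_dim=None,  # derived as out_channels // heads
    block_out_channels=(320, 640, 1280, 1280),
)
print(num_attention_heads)  # (8, 8, 8, 8)
print(attention_head_dim)   # [40, 80, 160, 160]
```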
```diff
@@ -419,12 +454,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         if mid_block_only_cross_attention is None:
             mid_block_only_cross_attention = False
 
-        if isinstance(num_attention_heads, int):
-            num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
         if isinstance(cross_attention_dim, int):
             cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
 
```