Compare commits

1 Commit

Author        SHA1         Message   Date
Dhruv Nair    327243fd95   update    2024-02-19 14:39:46 +00:00


@@ -55,6 +55,42 @@ from .unet_2d_blocks import (
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
+def _check_to_assign_to_num_attention_heads(attention_head_dim, down_block_types, mid_block_type, up_block_types):
+    # If a new config is passed with the correct value for `num_attention_heads` and `attention_head_dim` is None,
+    # we do not need to reassign anything
+    if attention_head_dim is None:
+        return False
+    elif (
+        "CrossAttnDownBlock2D" in down_block_types
+        or "CrossAttnUpBlock2D" in up_block_types
+        or mid_block_type == "UNetMidBlock2DCrossAttn"
+    ):
+        return True
+    return False
+
+
+def _set_attention_parameters(num_attention_heads, attention_head_dim, block_out_channels):
+    if num_attention_heads is None:
+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(block_out_channels)
+        num_attention_heads = [
+            out_channel // attention_head_dim[i] for i, out_channel in enumerate(block_out_channels)
+        ]
+    elif attention_head_dim is None:
+        if isinstance(num_attention_heads, int):
+            num_attention_heads = (num_attention_heads,) * len(block_out_channels)
+        attention_head_dim = [
+            out_channel // num_attention_heads[i] for i, out_channel in enumerate(block_out_channels)
+        ]
+    return num_attention_heads, attention_head_dim
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
@@ -225,18 +261,22 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
         self.sample_size = sample_size
 
-        if num_attention_heads is not None:
+        if (num_attention_heads is not None) and (attention_head_dim is not None):
             raise ValueError(
-                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+                "It is not possible to configure the UNet with both `num_attention_heads` and `attention_head_dim` at the same time. Please set only one of these values."
             )
 
-        # If `num_attention_heads` is not defined (which is the case for most models)
-        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
-        # The reason for this behavior is to correct for incorrectly named variables that were introduced
-        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
-        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
-        # which is why we correct for the naming here.
-        num_attention_heads = num_attention_heads or attention_head_dim
+        _should_assign_num_attention_heads = _check_to_assign_to_num_attention_heads(
+            attention_head_dim, down_block_types, mid_block_type, up_block_types
+        )
+        if _should_assign_num_attention_heads:
+            num_attention_heads = attention_head_dim
+            attention_head_dim = None
+            logger.warning(
+                "`attention_head_dim` has been incorrectly configured for this model and will be reassigned to `num_attention_heads`. "
+                "Further reference: https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131"
+            )
+        self.register_to_config(num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim)
 
         # Check inputs
         if len(down_block_types) != len(up_block_types):
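
To illustrate when the reassignment branch fires (a rough sketch assuming the `_check_to_assign_to_num_attention_heads` helper from the first hunk is in scope; the block types below are assumed, SD-style values):

down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
up_block_types = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
mid_block_type = "UNetMidBlock2DCrossAttn"

# Legacy config: `attention_head_dim` really holds the head count, so it moves.
assert _check_to_assign_to_num_attention_heads(8, down_block_types, mid_block_type, up_block_types)

# New-style config: `attention_head_dim` is None, nothing to reassign.
assert not _check_to_assign_to_num_attention_heads(None, down_block_types, mid_block_type, up_block_types)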
@@ -254,16 +294,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
-            )
-
-        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
-            raise ValueError(
-                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
-            )
-
         if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
@@ -410,6 +440,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         self.down_blocks = nn.ModuleList([])
         self.up_blocks = nn.ModuleList([])
 
+        # Compute both attention_head_dim and num_attention_heads for each block
+        num_attention_heads, attention_head_dim = _set_attention_parameters(
+            num_attention_heads, attention_head_dim, block_out_channels
+        )
+
         if isinstance(only_cross_attention, bool):
             if mid_block_only_cross_attention is None:
                 mid_block_only_cross_attention = only_cross_attention
@@ -419,12 +454,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             if mid_block_only_cross_attention is None:
                 mid_block_only_cross_attention = False
 
-        if isinstance(num_attention_heads, int):
-            num_attention_heads = (num_attention_heads,) * len(down_block_types)
-
-        if isinstance(attention_head_dim, int):
-            attention_head_dim = (attention_head_dim,) * len(down_block_types)
-
         if isinstance(cross_attention_dim, int):
             cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
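
Likewise, the int-to-tuple broadcasts for `num_attention_heads` and `attention_head_dim` are dropped here since `_set_attention_parameters` now performs that expansion; only `cross_attention_dim` is still broadcast inline. A minimal sketch of the surviving broadcast, with assumed values:

down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
cross_attention_dim = 768
if isinstance(cross_attention_dim, int):
    cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
print(cross_attention_dim)  # (768, 768, 768, 768)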